In [1]:
%matplotlib inline
import matplotlib.pyplot as plt

import numpy as np
import tensorflow as tf

In [104]:
# get some data
import requests
import os
import tarfile
import re
import xml.etree.ElementTree as ET
import collections
import math

def maybe_download(data_path='reuters21578.tar.gz'):
    """might download Reuters-21578"""
    if not os.path.exists(data_path):
        print('Downloading dataset :)')
        url = 'https://archive.ics.uci.edu/ml/machine-learning-databases/reuters21578-mld/reuters21578.tar.gz'
        r = requests.get(url, stream=True)
        with open(data_path, 'wb') as f:
            for chunk in r.iter_content(chunk_size=1024):
                if chunk: # could be keep-alive chunks
                    f.write(chunk)
    # quickly validate the data
    size = os.path.getsize(data_path) >> 20
    if size != 7:  # wrong size :(
        raise ValueError('data file is wrong size ({}).'.format(size))
    return data_path

def read_reuters_file(filename, vocab_freqs, token_func):
    """Reads in a whole file & finds the necessary bits
    Returns a list of [text, label, set, id] where set is one of {train, test, valid}.
    according to the "ModHayes" split. Adds to counts in vocab_freqs, tokenizing using the 
    token_func function.
    """
    # the files have some things that make the xml parser unhappy
    # there are probably efficient ways to do this, if so this is
    # not one of them
    with open(filename, errors='ignore') as raw_file:
        file_str = raw_file.read()
        file_str = re.sub('&#\d{1,2};', '', file_str)
        # the docs don't have any kind of root tags
        # so we skip the doctype and wrap them in one
        file_str = '<root>' + file_str[file_str.find('\n'):-1] + '</root>'
        print('..parsing {}'.format(filename))
        root = ET.fromstring(file_str)
        data = []
        for child in root:
            if child.attrib['TOPICS'] == 'YES':  # we need to be able to evaluate
                try:
                    text = child.find('./TEXT/BODY').text
                except AttributeError:
                    text = child.find('./TEXT').text  # should check type=brief
                text = token_func(text)
                for symbol in text:
                    vocab_freqs[symbol] += 1
                topics = [d.text for d in child.findall('./TOPICS/D')]
                data.append(
                    [text, topics, 
                    'train' if child.attrib['CGISPLIT'] == 'TRAINING-SET' else 'test',
                    child.attrib['NEWID']])
        return data
    
def word_split(text):
    """roughly tokenise into words. should do some stemming or something"""
    # replace numbers with a special token
    text = text.casefold()
    text = re.sub('(\d+([.,]?))+', ' <NUMBER> ', text)
    text = re.sub('[.?!]', ' <STOP> ', text)
    # remove remaining punctuation
    text = re.sub(r'[^\w\s<>]', ' <PUNCT> ', text)
    return text.split()

def char_split(text):
    """just spit it out as a list of characters"""
    return list(text)

def report_statistics(data):
    """data is a sequence of sequences"""
    cumulative = 0
    longest = -1
    shortest = 1000000
    for item in data:
        length = len(item)
        cumulative += length
        if length > longest:
            longest = length
        if length < shortest:
            shortest = length
    mean = cumulative/len(data)
    cumulative = 0
    for item in data:
        cumulative += (mean - len(item))**2
    stddev = math.sqrt(cumulative / len(data))
    print('..mean: {} (stddev: {})'.format(mean, stddev))
    print('..max: {}'.format(longest))
    print('..min: {}'.format(shortest))

def get_reuters(data_dir='data', level='word', min_reps=5):
    """gets the reuters dataset as (training, test, vocab).
    training and test are sequences of sequences of ints (we'll tensor them up
    when we iterate them later) and vocab is a list of strings.
    """
    if not os.path.exists(data_dir):
        with tarfile.open(maybe_download(), 'r:gz') as datafile:
            print('Extracting archive')
            datafile.extractall(path=data_dir)
    # the data is probably small enough that we will just hold it in memory
    filenames = [os.path.join(data_dir, f) for f in os.listdir(data_dir) if re.search('.sgm$', f)]
    all_data = []
    vocab_freqs = collections.Counter()
    split_func = word_split if level == 'word' else char_split
    for filename in filenames:
        start = len(all_data)
        all_data.extend(read_reuters_file(filename, vocab_freqs, split_func))
        end = len(all_data)
        #print('...{} new items'.format(end-start))
    print('got {} in total'.format(len(all_data)))
    print('vocab: {}'.format(len(vocab_freqs)))
    print('Top 10: ')
    ordered_words = vocab_freqs.most_common()
    for word, count in ordered_words[0:10]:
        print('  {} ({})'.format(word, count))
    print('Bottom 10: ')
    for word, count in ordered_words[-1:-10:-1]:
        print('  {} ({})'.format(word, count))

    print('removing least frequent (<= {} repetitions)'.format(min_reps))
    to_remove = set()
    for word, count in ordered_words[-1:0:-1]:
        # loop through backwards
        if count > min_reps:
            break  # done
        to_remove.add(word)
        del vocab_freqs[word]
    print('..{} symbols to remove'.format(len(to_remove)))
    print('..new vocab size {}'.format(len(vocab_freqs)))
    for item in all_data:
        # ditch the ones we don't want
        item[0] = ['<UNK>' if i in to_remove else i for i in item[0]]
    
    print('Converting to id sequences...')
    # now we need to map the vocab to integers
    # map the specials by hand
    symbol_to_id = {'<GO>': 0, '<UNK>': 1, '<STOP>': 3, '<PUNCT>': 4, '<PAD>': 5}
    id_num = 1  # save 0 for the GO symbol
    for symbol in vocab_freqs:
        symbol_to_id[symbol] = id_num
        id_num += 1
    for item in all_data:
        item[0] = [symbol_to_id[i] for i in item[0]]
    #for file in all_data[0:5]:
     #   print(file)
    # collect training sequences
    training = [item[0] for item in all_data if item[2] == 'train']
    
    # let's have a look at some stats
    report_statistics(training)
    
    return (training, None, symbol_to_id)

In [105]:
training, _, vocab = get_reuters(level='word', min_reps=50)

..parsing data/reut2-000.sgm
..parsing data/reut2-001.sgm
..parsing data/reut2-002.sgm
..parsing data/reut2-003.sgm
..parsing data/reut2-004.sgm
..parsing data/reut2-005.sgm
..parsing data/reut2-006.sgm
..parsing data/reut2-007.sgm
..parsing data/reut2-008.sgm
..parsing data/reut2-009.sgm
..parsing data/reut2-010.sgm
..parsing data/reut2-011.sgm
..parsing data/reut2-012.sgm
..parsing data/reut2-013.sgm
..parsing data/reut2-014.sgm
..parsing data/reut2-015.sgm
..parsing data/reut2-016.sgm
..parsing data/reut2-017.sgm
..parsing data/reut2-018.sgm
..parsing data/reut2-019.sgm
..parsing data/reut2-020.sgm
..parsing data/reut2-021.sgm
got 13476 in total
vocab: 30639
Top 10: 
  <PUNCT> (136070)
  <NUMBER> (109129)
  the (91914)
  <STOP> (82619)
  of (47890)
  to (44521)
  in (35331)
  and (34636)
  a (33162)
  said (33096)
Bottom 10: 
  reintensification (1)
  betweenm (1)
  robustas (1)
  torino> (1)
  elaboration (1)
  disution (1)
  overvaluing (1)
  cos> (1)
  perconal (1)
removing least

In [133]:
# now let's see about making a seq2seq model to learn this biz
batch_size = 128
num_layers = 1
num_units = 2
def get_model(inputs, first=True):
    # we should use buckets to efficiently handle variable length
    # (it is highly variable)
    GO = tf.constant([vocab['<GO>']]*batch_size, dtype=tf.int32)
    loss_weights = tf.constant([1/batch_size] * batch_size)
    with tf.variable_scope('rnn') as scope:
        if not first:
            scope.reuse_variables()
        enc_inputs = [tf.placeholder(
                        tf.int32, shape=[None],name="encoder{}".format(i))
                      for i in range(len(inputs))]
        
        outputs, state = tf.nn.seq2seq.embedding_tied_rnn_seq2seq(
            enc_inputs, 
            [GO]*len(inputs),  # decoder inputs
            tf.nn.rnn_cell.MultiRNNCell(
                [tf.nn.rnn_cell.LSTMCell(
                    num_units,
                    num_units,  # number of inputs (embedding handles the rest)
                    )]*num_layers),
            len(vocab),
            output_projection=None,  # probably should be some
            dtype=tf.float32,
            feed_previous=True)
#         with tf.variable_scope('output_proj',
#                                initializer=tf.random_normal_initializer()):
#             softmax_w = tf.get_variable('W', [len(vocab), num_units])
#             softmax_b = tf.get_variable('b', [len(vocab)])
#             logits = [tf.matmul(softmax_w, output, transpose_b=True) + b for output in outputs]
        # TODO(pfcm): calculate weights depending on padding
        loss = tf.nn.seq2seq.sequence_loss(outputs,
                                           inputs, # goal is to reproduce
                                           [loss_weights]*len(inputs))
        # now optimise
        lr = tf.get_variable('learning_rate', [1], tf.float32, trainable=False)
        mo = tf.get_variable('momentum', [1], tf.float32, trainable=False)
        gn = tf.get_variable('grad_norm', [1], tf.float32, trainable=False)
        optimiser = tf.train.MomentumOptimizer(lr, mo)
        params = tf.trainable_variables()
        grads = tf.gradients(loss, params)
        clipped_grads, norm = tf.clip_by_global_norm(grads, gn)
        train_op = optimiser.apply_gradients(
            zip(clipped_grads, params))
        save = tf.train.Saver(tf.all_variables())
        return enc_inputs, loss, train_op, (lr, mo, gn)

In [134]:
import itertools
def batch_iterator(data, batch_size):
    """generator to iterate batches"""
    idces = np.arange(len(data))
    np.random.shuffle(idces)
    def partition(itr, n):
        r = iter(itr)
        res = None
        while True:
            res = list(itertools.islice(itr, 0, n))
            if res == []:
                break
            yield res
    
    # iterate batch_size chunks of idces
    for idx_chunk in partition(idces, batch_size):
        # grab the sequences and batch em up
        batch = []
        batch_list = [data[i] for i in idx_chunk]
        # what is the longest item in the batch?
        longest = max((len(item) for item in batch_list))
        for i in range(longest):
            batch.append(np.array(
                [item[i] if i < len(item) else vocab['<PAD>'] for item in batch_list]))
        yield batch


In [None]:
# clear the graph for safe re-running of this cell
tf.reset_default_graph()
# start a session
learning_rate = 0.1
momentum = 0.85
gradnorm = 5
with tf.Session() as sess:
    num_epochs = 10
    for epoch in range(num_epochs):
        loss = 0
        num_batches = 0
        for batch in batch_iterator(training, batch_size):
            # unrolls per batch, probably silly
            # although we will see once it actually does a couple of batches
            in_vars, loss_op, train_op, hps = get_model(batch,
                                                   epoch==0 and num_batches==0)
            feed = {zip(in_vars, batch)}
            feed[hps[0]] = learning_rate
            feed[hps[1]] = momentum
            feed[hps[2]] = gradnorm
            batch_loss, _ = sess.run([loss_op, train_op],
                                     {zip(in_vars, batch)})
            loss += batch_loss
            num_batches += 1
            print('.', end='', flush=True)
        print('\nEpoch {}, loss {}'.format(epoch+1, loss))