In [1]:
import tensorflow as tf
from tensorflow.python.ops import lookup_ops
import numpy as np
import codecs
import time
import logging
import sys
import re
import string
import os
from collections import Counter

In [2]:
def _preprocess(line):
    '''
    performs the following preprocessing on a line of text.
    - lower case the text
    - remove punctuations
    - try alternate decoding if utf-8 fails
    - encode string as utf-8
    Note this is a pyfunc as tensorflow doesn't seem to have string operations beyond splitting
    
    Returns: a utf-8 encoded processed string
    '''
    try:
        line = line.decode('utf-8')
        line = line.lower()
        line = line.strip()
        line = re.sub('['+string.punctuation+']', '', line)
    except:
        try:
            line = line.decode('iso-8859-1')
            line = line.lower()
            line = re.sub('['+string.punctuation+']', '', line)
        except:
            return line.lower()
    return line.encode('utf-8')

def _add_vocab_file_generator(text_files):
    '''
    Adds a vocabulary-extraction input pipeline to the current graph, built on
    the tf.contrib.data Dataset API. Every line of every file is run through
    _preprocess and then split on whitespace.

    Arguments:
    text_files - file names to read; wrapped in tf.constant before use

    Returns: the get_next() tensor of a one-shot iterator over the pipeline.
             Each evaluation yields the tokens of the next line.
    '''
    def _read_lines(filename):
        # one dataset of raw lines per input file
        return tf.contrib.data.TextLineDataset(filename)

    def _clean(raw_line):
        # lower-casing / punctuation removal happens in Python via py_func
        return tf.py_func(_preprocess, [raw_line], [tf.string])

    def _tokenize(clean_line):
        # basic tokenization - split on whitespace
        return tf.string_split([clean_line]).values

    with tf.name_scope('vocab_gen'):
        file_names = tf.contrib.data.Dataset.from_tensor_slices(tf.constant(text_files))
        tokens = file_names.flat_map(_read_lines).map(_clean).map(_tokenize)
        return tokens.make_one_shot_iterator().get_next()


def gen_vocab_file(file_names, vocabulary_file='./vocab.txt', max_vocab_size=None,markers = True, reset_graph=False):
    '''
    generates the vocabulary file specified by vocabulary_file by draining the
    token tensor returned by _add_vocab_file_generator and counting every token.
    Words are written one per line, most frequent first.

    Arguments:
    file_names - a list of filenames to process
    vocabulary_file - the name of output file
    max_vocab_size - the maximum number of words in the vocabulary. None implies all words will be included, else
                     only the max_vocab_size most common words will be included (markers are not counted)
    markers - add the mandatory <PAD>, <UNK>, <SOS>, <EOS> markers to the beginning of vocab file
    reset_graph - reset current graph before running this function.

    Returns: None; the result is the file written to vocabulary_file.

    Note (from the original author): Gives some issues on Windows with special characters in file names.
    Otherwise tested to work well on corpus with a million files and up to 4 billion tokens
    '''
    if reset_graph:
        tf.reset_default_graph()

    next_line = _add_vocab_file_generator(file_names)

    vocab = Counter()
    line_count = 0
    word_count = 0
    logging.info('Start generating %s from %d files', vocabulary_file, len(file_names))
    start_time = time.time()

    # the context manager closes the session; no explicit close() is needed
    with tf.Session() as sess:
        while True:
            try:
                word_list = sess.run(next_line)
            except tf.errors.OutOfRangeError:
                # the one-shot iterator signals exhaustion this way
                logging.debug("Completed:%d lines and %d words processed", line_count, word_count)
                break
            vocab.update(word_list)
            line_count += 1
            word_count += len(word_list)
            if line_count % 100000 == 0:
                logging.debug("%d lines and %d words processed", line_count, word_count)

    # a StreamWriter (not a reader) is the right wrapper for an output file:
    # it encodes text to utf-8 bytes, hence the binary "wb" mode underneath
    with codecs.getwriter("utf-8")(tf.gfile.GFile(vocabulary_file, "wb")) as vocab_file:
        if markers:
            vocab_file.write('<PAD>\n<UNK>\n<SOS>\n<EOS>\n')
        for word, _ in vocab.most_common(max_vocab_size):
            # tokens come back from the session as utf-8 encoded bytes
            vocab_file.write("{}\n".format(word.decode('utf-8')))
    logging.info('Completed generating %s in %d s', vocabulary_file, time.time() - start_time)

In [3]:
# Generate the vocabulary file from the wiki data in ./textwiki8. The output
# goes to the default location (./vocab.txt); the graph is reset first so the
# pipeline is built on a clean default graph.
gen_vocab_file(['./textwiki8'],reset_graph=True)

In [4]:
def create_padded_sequences(batch_size,text_files,vocab_file='./vocab.txt',
                            reset_graph=False):
    '''
    Builds an input pipeline that turns each line of the source text files into
    a vocabulary-coded integer sequence plus its length, and batches the lines
    with padding up to the longest sequence in the batch, e.g.
    'Hello World there','Bye there' -> ([12, 10, 15],3),([21, 15, 0],2)

    Arguments:
    ----------
    batch_size - number of lines per batch
    text_files - list of file names of text files containing data
    vocab_file - path to the vocabulary file used to translate words to ids.
                 This file must have the format
                 <PAD>
                 <UNK>
                 <SOS>
                 <EOS>
                 word1
                 word2
                 ....
    reset_graph - [optional], reset the default graph before building

    Returns:
    --------
    iterator - initializable iterator producing batches of
               (sentence ids, sentence lengths), i.e
               ([[12,10,15],[21,15,0]], [3,2])
    vocab_table - tensorflow hashtable containing word to index mapping
    reverse_vocab_table - tensorflow reverse hashtable containing index to word mapping

    '''
    if reset_graph is True:
        tf.reset_default_graph()

    # out-of-vocabulary words map to id 1, the <UNK> line of the vocab file
    unk_id = 1
    word_to_id = lookup_ops.index_table_from_file(vocab_file, num_oov_buckets=0,
                                                  default_value=unk_id)
    id_to_word = lookup_ops.index_to_string_table_from_file(vocab_file)

    def _read_lines(filename):
        return tf.contrib.data.TextLineDataset(filename)

    def _clean(raw_line):
        return tf.py_func(_preprocess, [raw_line], [tf.string])

    def _tokenize(clean_line):
        return tf.string_split([clean_line]).values

    def _encode(words):
        # pair every id sequence with its unpadded length
        return (word_to_id.lookup(words), tf.size(words))

    files = tf.contrib.data.Dataset.from_tensor_slices(tf.constant(text_files))
    # shuffling happens at file-name level, before lines are read
    files = files.shuffle(buffer_size=100)
    encoded = files.flat_map(_read_lines).map(_clean).map(_tokenize).map(_encode)

    pad_id = tf.cast(word_to_id.lookup(tf.constant('<PAD>')), dtype=tf.int64)
    sequence_shape = tf.TensorShape([None])   # variable-length id sequence
    length_shape = tf.TensorShape([])         # scalar length
    batched = encoded.padded_batch(batch_size,
                                   padded_shapes=(sequence_shape, length_shape),
                                   padding_values=(pad_id, 0))
    return batched.make_initializable_iterator(), word_to_id, id_to_word

In [5]:
# Build the padded-sequence pipeline for one file with a batch size of 5.
# Nothing is read yet: this only returns the iterator and the two lookup tables.
iterator, table, reverse_table = create_padded_sequences(5,['./test-text.txt'])

In [6]:
# Open an interactive session, then initialise the vocabulary lookup tables
# and the dataset iterator before any batch is requested.
sess = tf.InteractiveSession()
sess.run(tf.tables_initializer())
sess.run(iterator.initializer)

In [7]:
# Fetch one batch: `sent` is the padded matrix of word ids, `sent_len` the
# number of tokens in each sentence.
sent, sent_len = sess.run(iterator.get_next())
print(sent)
print(sent_len)

[[  4929   3087     15      9    185      5   2843     51     61    160
     130    781    461  10180    137      4  25285      5      4    116
     893      6      4  16153  53813      5      4    165    893]
 [  1673     87    123      8      7     17     26     24      4    187
   69474     18    745      0      0      0      0      0      0      0
       0      0      0      0      0      0      0      0      0]
 [   242     88      4     51    187     92     50   3248    236   2140
       0      0      0      0      0      0      0      0      0      0
       0      0      0      0      0      0      0      0      0]
 [  6491    137   4929    232     10     48   6659    164     10   1013
    1865      0      0      0      0      0      0      0      0      0
       0      0      0      0      0      0      0      0      0]
 [  1893      4    187    145 113639     31      4      0      0      0
       0      0      0      0      0      0      0      0      0      0
       0      0 

In [8]:
# Verify the reverse lookup works: map the ids of the fetched batch back to
# their words (padding positions come back as <PAD>).
reverse_table.lookup(tf.constant(sent, dtype=tf.int64)).eval()

array([[b'anarchism', b'originated', b'as', b'a', b'term', b'of', b'abuse',
        b'first', b'used', b'against', b'early', b'working', b'class',
        b'radicals', b'including', b'the', b'diggers', b'of', b'the',
        b'english', b'revolution', b'and', b'the', b'sans', b'culottes',
        b'of', b'the', b'french', b'revolution'],
       [b'twenty', b'years', b'later', b'in', b'one', b'eight', b'six',
        b'four', b'the', b'international', b'workingmen', b's',
        b'association', b'<PAD>', b'<PAD>', b'<PAD>', b'<PAD>', b'<PAD>',
        b'<PAD>', b'<PAD>', b'<PAD>', b'<PAD>', b'<PAD>', b'<PAD>',
        b'<PAD>', b'<PAD>', b'<PAD>', b'<PAD>', b'<PAD>'],
       [b'sometimes', b'called', b'the', b'first', b'international',
        b'united', b'some', b'diverse', b'european', b'revolutionary',
        b'<PAD>', b'<PAD>', b'<PAD>', b'<PAD>', b'<PAD>', b'<PAD>',
        b'<PAD>', b'<PAD>', b'<PAD>', b'<PAD>', b'<PAD>', b'<PAD>',
        b'<PAD>', b'<PAD>', b'<PAD>', b'<PAD>',

In [9]:
# Close the interactive session opened in cell In [6].
sess.close()