In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os.path
import time
import tensorflow as tf
import numpy as np

In [2]:
# TF command-line flags.
# NOTE(review): nothing visible in this notebook reads FLAGS -- the training
# cells hard-code their own batch size, epoch count and learning rate instead.
# Re-running this cell in the same kernel will probably fail because the flags
# are already registered; confirm for the installed TF version.
flags = tf.app.flags
FLAGS = flags.FLAGS

_FLAG_SPECS = (
    (flags.DEFINE_float,   'learning_rate', 0.01, 'Initial learning rate.'),
    (flags.DEFINE_integer, 'num_epochs',    5,    'Number of epochs to run trainer.'),
    (flags.DEFINE_integer, 'batch_size',    4096, 'Batch size.'),
)
for _define, _flag_name, _default, _help in _FLAG_SPECS:
    _define(_flag_name, _default, _help)
del _define, _flag_name, _default, _help

In [3]:
# Every file under data_v2/ is used for training; no validation split exists.
# NOTE(review): the match result is a graph variable, presumably a local one --
# the training cell runs the local-variables initializer before starting the
# queue runners, which this relies on.
train_files = data_files = tf.train.match_filenames_once("data_v2/*")

In [4]:
def extract_features(serialized_example):
    """
    Parse one serialized `tf.train.Example` into a `dict` of named int64 tensors.

    Every feature is fixed-length:
        input_dense_dimensions  -> [2]
        sparse_index_dimensions -> [2]
        input                   -> [80]
        label                   -> [1]
        label_length            -> [1]
    """
    def _int64_vector(length):
        # Fixed-length int64 feature of the given length.
        return tf.FixedLenFeature([length], dtype=tf.int64)

    schema = {
        'input_dense_dimensions': _int64_vector(2),
        'sparse_index_dimensions': _int64_vector(2),
        'input': _int64_vector(80),
        'label': _int64_vector(1),
        'label_length': _int64_vector(1),
    }
    return tf.parse_single_example(serialized=serialized_example, features=schema)

def deserialize_example(serialized_example, dense_shape=(20, 111), num_labels=99):
    """
    Converts a serialized `tf.train.Example` to FP32 Tensors.

    The flat `input` feature holds coordinates that are reshaped to
    `sparse_index_dimensions` (a 2-element shape, i.e. [num_indices, rank]).
    Every listed coordinate is set to 1.0 in a dense matrix of `dense_shape`;
    all other entries are 0.0.

    NOTE(review): the example's own `input_dense_dimensions` feature is not
    used -- the dense shape is static so downstream batching gets a known
    shape. Confirm it matches the data. Also assumes the stored coordinates
    are in lexicographic order without duplicates, since `tf.sparse_to_dense`
    validates indices by default.

    Args:
        serialized_example: string Tensor holding one serialized Example.
        dense_shape: shape of the dense input matrix (default (20, 111)).
        num_labels: depth of the one-hot label vector (default 99).

    Returns:
        (input, label): float32 Tensors shaped `dense_shape` and [num_labels].
    """
    features = extract_features(serialized_example)
    sparse_shape = tf.cast(features['sparse_index_dimensions'], tf.int32)
    # Un-flatten the stored coordinates into a [num_indices, rank] matrix.
    indices = tf.reshape(tf.cast(features['input'], tf.int32), sparse_shape)
    # One 1.0 value per coordinate row.
    values = tf.ones([sparse_shape[0]])
    dense_input = tf.sparse_to_dense(indices, dense_shape, values)
    label = tf.one_hot(features['label'][0], num_labels,
                       on_value=1., off_value=0., dtype=tf.float32)
    return dense_input, label

In [5]:
def read_and_decode(filename_queue):
    """
    Pull the next serialized `tf.train.Example` from a queue of TFRecord
    filenames and turn it into (input, label) tensors via `deserialize_example`.
    """
    record_reader = tf.TFRecordReader()
    # read() yields a (key, value) pair; only the serialized value is needed.
    serialized = record_reader.read(filename_queue)[1]
    return deserialize_example(serialized)

In [6]:
def get_batch(batch_size=None, num_epochs=None, train=True, num_threads=20):
    """
    Read in shuffled `inputs` and `labels` from the files in `train_files`.

    Args:
        batch_size: examples per batch; falsy values (None, 0) fall back to 128.
        num_epochs: passes over the file list; falsy values fall back to 1.
            The training loop relies on the resulting
            `tf.errors.OutOfRangeError` to know when to stop.
        train: NOTE(review): currently has no effect -- no validation file list
            exists, so training files are always read. A warning is logged for
            `train=False` so callers are not silently handed training data.
        num_threads: threads enqueuing decoded examples (default 20).

    Returns:
        `inputs` : [batch_size, 20, 111]
        `labels` : [batch_size, 99]
    """
    batch_size = batch_size or 128
    num_epochs = num_epochs or 1
    if not train:
        tf.logging.warning(
            'get_batch(train=False): no validation files are configured; '
            'reading training files instead.')

    with tf.name_scope('input'):
        filename_queue = tf.train.string_input_producer(train_files, num_epochs=num_epochs)
        example_input, example_label = read_and_decode(filename_queue)
        # Room for four batches; keep at least one batch buffered after each
        # dequeue so the shuffle has something to mix.
        inputs, labels = tf.train.shuffle_batch(
            [example_input, example_label],
            batch_size=batch_size,
            num_threads=num_threads,
            capacity=4 * batch_size,
            min_after_dequeue=batch_size,
        )
    return inputs, labels

## Train the ORNL model

In [7]:
import keras
from keras.models import Sequential
from keras.layers import Dense, Activation, Dropout
from keras.layers import LSTM
from keras.optimizers import RMSprop
from keras.objectives import categorical_crossentropy

Using TensorFlow backend.


In [None]:
# Model hyperparameters.
# NOTE(review): the tf.app.flags values (learning_rate/num_epochs/batch_size)
# are not used here; these literals are what actually drive training.
rnn_size = 50         # LSTM hidden units
seq_length = 20       # time steps per example
num_vocab = 99        # output classes; equals the one-hot label depth
num_classes = 12      # extra input columns per step (width = num_vocab + num_classes = 111)

batch_size = 2048
num_epochs = 10
learning_rate = 0.0005  # RMSProp step size

with tf.name_scope("nn"):
    model = Sequential()
    # Input width must match the dense width produced by deserialize_example.
    model.add(LSTM(rnn_size, input_shape=(seq_length, num_vocab + num_classes)))
    model.add(Dense(num_vocab))
    model.add(Activation('softmax'))


inputs, labels = get_batch(batch_size=batch_size, num_epochs=num_epochs)
# Despite the name, `logits` holds softmax probabilities (the model ends in a
# softmax layer), which is what Keras' categorical_crossentropy takes by default.
logits = model(inputs)
loss = tf.reduce_mean(categorical_crossentropy(labels, logits))

train_step = tf.train.RMSPropOptimizer(learning_rate).minimize(loss)

In [None]:
# Run the queue-fed training loop until the input pipeline is exhausted after
# `num_epochs` passes, which surfaces as tf.errors.OutOfRangeError.
print_every = 50  # steps between progress lines

with tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sess:
    # Register this session as Keras's backend session.
    keras.backend.set_session(sess)

    # Local variables must be initialized too; the input pipeline's state
    # (matched file list / epoch limit) presumably lives there.
    init_op = tf.group(tf.initialize_all_variables(),
                       tf.initialize_local_variables())
    sess.run(init_op)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    step = 0
    try:
        while not coord.should_stop():
            start_time = time.time()
            # Batches come from the input queues, so no feed_dict is needed.
            _, loss_value = sess.run([train_step, loss])
            duration = time.time() - start_time

            # Print an overview fairly often.
            if step % print_every == 0:
                print(step, loss_value, duration)

            step += 1

    except tf.errors.OutOfRangeError:
        # Report the epoch count actually used, not a hard-coded copy of it.
        print('Done training for %d epochs, %d steps.' % (num_epochs, step))
    finally:
        # When done, ask the threads to stop.
        coord.request_stop()

    # Wait for threads to finish.
    coord.join(threads)

0 4.61715 1.68294811249
50 4.61187 0.701879024506
100 4.53142 0.672744035721
150 3.98735 0.717196941376
200 3.68683 0.696604967117
250 3.53813 0.708003044128
300 3.45296 0.711354970932
350 3.44111 0.685532093048
400 3.34732 0.696856975555
450 3.36381 0.705395936966
500 3.27871 0.673691034317
550 3.24115 0.702242136002
