In [1]:
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.python.ops import rnn, rnn_cell

# Import MNIST dataset

In [2]:
# load MNIST dataset with one-hot encoded labels.
# The .gz archives are read from /tmp/data/ (see the "Extracting ..." output);
# `input_data` is imported together with the other imports in the first cell.
mnist_data = input_data.read_data_sets("/tmp/data/", one_hot=True)

Extracting /tmp/data/train-images-idx3-ubyte.gz
Extracting /tmp/data/train-labels-idx1-ubyte.gz
Extracting /tmp/data/t10k-images-idx3-ubyte.gz
Extracting /tmp/data/t10k-labels-idx1-ubyte.gz


In [3]:
# total amount of train and test sets
# Single-argument print(...) calls produce the same text on Python 2 and 3
# (a bare `print 'a', b` statement is a syntax error on Python 3). The '{} {}'
# format reproduces the space Python 2's comma-separated print inserted.
print('{} {}'.format('Total train MNIST images : ', len(mnist_data.train.images)))
print('{} {}'.format('Total test MNIST images: ', len(mnist_data.test.images)))

Total train MNIST images :  55000
Total test MNIST images:  10000


### To classify images with an RNN, we treat each image as a sequence of its rows. Because MNIST images are 28×28 px, every sample becomes one sequence of 28 time steps, where each step is a row of 28 pixel values.

In [4]:
# Parameters
learning_rate = 0.001

# NOTE: despite its name, this is the total number of training *samples*
# (not epochs) to draw: training stops once step * batch_size reaches it.
# With the 55000 training images printed above, that is roughly 3.6 passes
# over the training set.
total_epoch = 200000

# Generally, the smaller the batch size the noisier the updates,
# so if you decrease the batch size you should probably decrease the learning rate,
# and train for more iterations.
batch_size = 256

# log minibatch loss/accuracy every `display_step` optimisation steps
display_step = 10

# Seed TensorFlow's graph-level RNG (used by tf.random_normal below) and
# numpy's global RNG so that Restart & Run All reproduces the same run.
# This must happen before any random op is added to the graph.
# NOTE(review): the MNIST loader's next_batch shuffling presumably draws from
# numpy's global RNG -- verify for the TensorFlow version in use.
random_seed = 42
np.random.seed(random_seed)
tf.set_random_seed(random_seed)

# Network Parameters
# MNIST data input: each time step is one image row of 28 pixels
total_input = 28

# total timestep: one step per image row (28 rows)
total_timestep = 28

# hidden layer num of features (LSTM output size)
total_hidden = 256

# MNIST total classes (0-9 digits)
total_classes = 10

# tensorflow graph input
# x: (batch, total_timestep, total_input) -- each sample is a sequence of rows
# y: (batch, total_classes) -- one-hot labels (the dataset is loaded with one_hot=True)
x = tf.placeholder("float", [None, total_timestep, total_input])
y = tf.placeholder("float", [None, total_classes])

# define model weight and bias: a linear projection from the LSTM's last
# output (total_hidden features) to the total_classes logits
W = {'weight': tf.Variable(tf.random_normal([total_hidden, total_classes]))}

b = {'bias': tf.Variable(tf.random_normal([total_classes]))}

# Construct RNN (LSTM)

In [5]:
def RNN(x, W, b):
    """Run an LSTM over the image rows and project its last output to logits.

    Args:
        x: input tensor laid out (batch_size, total_timestep, total_input),
            matching the `x` placeholder defined with the parameters.
        W: dict whose 'weight' entry is the (total_hidden, total_classes)
            output projection matrix.
        b: dict whose 'bias' entry is the (total_classes,) output bias.

    Returns:
        Unscaled logits of shape (batch_size, total_classes); no softmax is
        applied here (the loss applies it).
    """
    # The static `rnn.rnn` API wants a Python list holding one
    # (batch_size, total_input) tensor per time step. Move the time axis to
    # the front, fold it into the batch axis, then cut the result back into
    # total_timestep equal slices along axis 0 (pre-1.0 tf.split argument
    # order: split_dim, num_split, value).
    time_major = tf.transpose(x, perm=[1, 0, 2])
    stacked_steps = tf.reshape(time_major, [-1, total_input])
    step_inputs = tf.split(0, total_timestep, stacked_steps)

    # forget_bias=1.0 is added to the forget-gate bias of the LSTM cell.
    cell = rnn_cell.BasicLSTMCell(total_hidden, forget_bias=1.0)
    step_outputs, _ = rnn.rnn(cell, step_inputs, dtype=tf.float32)

    # Only the output produced after the final row feeds the classifier.
    final_output = step_outputs[-1]
    return tf.matmul(final_output, W['weight']) + b['bias']

In [6]:
# Forward pass: unscaled class logits of shape (batch, total_classes).
rnn_model = RNN(x, W, b)

# define loss: mean softmax cross-entropy over the batch.
# logits/labels are passed by keyword because the positional order is not
# stable across TensorFlow releases (TF 1.x rejects positional calls, TF 2.x
# takes `labels` first); the same keyword names exist in the pre-1.0 API.
cost = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(logits=rnn_model, labels=y))

# adam optimizer: running this op performs one parameter update
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)

# model accuracy: fraction of samples whose highest-scoring class matches the
# index of the one-hot label
correct_prediction = tf.equal(tf.argmax(rnn_model, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

# Initializing the variables.
# NOTE: later TF releases deprecate tf.initialize_all_variables in favour of
# tf.global_variables_initializer(); it is kept because the rest of this
# notebook targets the pre-1.0 API (tf.split(0, ...), rnn.rnn).
init = tf.initialize_all_variables()

# Run the RNN model!

In [7]:
# Launch the graph
with tf.Session() as sess:
    sess.run(init)
    step = 1

    # training starts here! `total_epoch` counts training samples, so the
    # loop stops once step * batch_size would reach it.
    while step * batch_size < total_epoch:
        batch_x, batch_y = mnist_data.train.next_batch(batch_size)

        # Turn each flattened image into one sequence of total_timestep rows,
        # each row holding total_input pixels.
        batch_x = batch_x.reshape((-1, total_timestep, total_input))
        batch_feed = {x: batch_x, y: batch_y}

        # execute optimization (backprop)
        sess.run(optimizer, feed_dict=batch_feed)

        if step % display_step == 0:
            # Fetch minibatch loss and accuracy in a single run so the
            # forward pass over the batch is computed once, not twice.
            loss, acc = sess.run([cost, accuracy], feed_dict=batch_feed)
            # "Iterations" in the log is the number of samples seen so far.
            print('Iterations: ' + str(step * batch_size) + ', Minibatch Loss= ' +
                  '{:.6f}'.format(loss) + ', Training Accuracy= ' +
                  '{:.5f}'.format(acc))
        step += 1
    print('Training is complete!')
    print('')

    # calculate accuracy on the whole MNIST test split; lower test_len to
    # evaluate only a prefix of it
    test_len = len(mnist_data.test.images)
    test_data = mnist_data.test.images[:test_len].reshape((-1, total_timestep, total_input))
    test_label = mnist_data.test.labels[:test_len]

    test_accuracy = sess.run(accuracy, feed_dict={x: test_data, y: test_label})
    # str() plus the explicit space matches Python 2's `print 'a', b` output.
    print('Test Accuracy: ' + str(test_accuracy))

Iterations: 2560, Minibatch Loss= 1.709867, Training Accuracy= 0.39844
Iterations: 5120, Minibatch Loss= 1.086316, Training Accuracy= 0.69141
Iterations: 7680, Minibatch Loss= 1.091629, Training Accuracy= 0.62109
Iterations: 10240, Minibatch Loss= 0.749779, Training Accuracy= 0.76953
Iterations: 12800, Minibatch Loss= 0.875684, Training Accuracy= 0.71484
Iterations: 15360, Minibatch Loss= 0.555176, Training Accuracy= 0.82031
Iterations: 17920, Minibatch Loss= 0.425183, Training Accuracy= 0.86328
Iterations: 20480, Minibatch Loss= 0.287345, Training Accuracy= 0.89844
Iterations: 23040, Minibatch Loss= 0.240682, Training Accuracy= 0.90234
Iterations: 25600, Minibatch Loss= 0.317411, Training Accuracy= 0.87500
Iterations: 28160, Minibatch Loss= 0.293748, Training Accuracy= 0.90234
Iterations: 30720, Minibatch Loss= 0.324046, Training Accuracy= 0.87500
Iterations: 33280, Minibatch Loss= 0.200775, Training Accuracy= 0.94922
Iterations: 35840, Minibatch Loss= 0.222271, Training Accuracy= 0.9