In [1]:
import tensorflow as tf
from tensorflow.contrib import rnn

import numpy as np

from tqdm import tqdm

In [2]:
NUM_EPOCHS = 10
BATCH_SIZE = 128

CHUNK_SIZE = 28
NUM_CHUNK = 28
RNN_SIZE = 128

NUM_CLASS = 10

In [3]:
train_X, train_y = np.load('data/train_X.npy'), np.load('data/train_y.npy')
test_X, test_y = np.load('data/test_X.npy'), np.load('data/test_y.npy')
valid_X, valid_y = np.load('data/valid_X.npy'), np.load('data/valid_y.npy')
train_X.shape, train_y.shape

((48000, 28, 28, 1), (48000, 10))

In [4]:
X = tf.placeholder(tf.float32, [None, NUM_CHUNK, CHUNK_SIZE])
y = tf.placeholder(tf.float32)

In [5]:
def weight_variable(shape):
    weights = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(weights)

def bias_variable(shape):
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial)

In [6]:
def recurrent_neural_network(X):
    layer = {
        'weights': weight_variable([RNN_SIZE, NUM_CLASS]),
        'biases': bias_variable([NUM_CLASS])
    }

    X = tf.transpose(X, [1, 0, 2])
    X = tf.reshape(X, [-1, CHUNK_SIZE])
    X = tf.split(X, NUM_CHUNK, 0)

    lstm_cell = rnn.BasicLSTMCell(RNN_SIZE)
    outputs, states = rnn.static_rnn(lstm_cell, X, dtype=tf.float32)

    pred = tf.matmul(outputs[-1], layer['weights']) + layer['biases']
    return pred

In [7]:
def train():
    model = recurrent_neural_network(X)
    loss_op = tf.reduce_mean( tf.nn.softmax_cross_entropy_with_logits_v2(labels=y, logits=model))
    train_op = tf.train.AdamOptimizer().minimize(loss_op)

    # Evaluations
    correct_prediction = tf.equal(tf.argmax(model, 1), tf.argmax(y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    
    init_op = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init_op)

        for epoch in range(NUM_EPOCHS):
            print('Epoch {}'.format(epoch + 1))

            for index, offset in tqdm(list(enumerate(range(0, train_X.shape[0], BATCH_SIZE))), ncols=100):
                xs, ys = train_X[offset: offset + BATCH_SIZE], train_y[offset: offset + BATCH_SIZE]
                xs = xs.reshape((BATCH_SIZE, CHUNK_SIZE, NUM_CHUNK))
                sess.run(train_op, feed_dict={ X: xs, y: ys})

            train_accuracy = accuracy.eval(feed_dict={
                X: train_X.reshape((-1, CHUNK_SIZE, NUM_CHUNK)),
                y: train_y
            })
            validation_accuracy = accuracy.eval(feed_dict={
                X: valid_X.reshape((-1, CHUNK_SIZE, NUM_CHUNK)),
                y: valid_y
            })
            print('Training Accuracy: {}%\nValidation Accuracy: {}%\n'.format(train_accuracy, validation_accuracy))

In [8]:
train()

Epoch 1


100%|█████████████████████████████████████████████████████████████| 375/375 [00:28<00:00, 13.27it/s]


Training Accuracy: 0.8253750205039978%
Validation Accuracy: 0.8270000219345093%

Epoch 2


100%|█████████████████████████████████████████████████████████████| 375/375 [00:29<00:00, 12.90it/s]


Training Accuracy: 0.8412291407585144%
Validation Accuracy: 0.8370000123977661%

Epoch 3


100%|█████████████████████████████████████████████████████████████| 375/375 [00:28<00:00, 13.14it/s]


Training Accuracy: 0.8569999933242798%
Validation Accuracy: 0.8515833616256714%

Epoch 4


100%|█████████████████████████████████████████████████████████████| 375/375 [00:28<00:00, 13.03it/s]


Training Accuracy: 0.8658541440963745%
Validation Accuracy: 0.8600833415985107%

Epoch 5


100%|█████████████████████████████████████████████████████████████| 375/375 [00:28<00:00, 13.25it/s]


Training Accuracy: 0.8705624938011169%
Validation Accuracy: 0.8627499938011169%

Epoch 6


100%|█████████████████████████████████████████████████████████████| 375/375 [00:28<00:00, 13.22it/s]


Training Accuracy: 0.8716874718666077%
Validation Accuracy: 0.8647500276565552%

Epoch 7


100%|█████████████████████████████████████████████████████████████| 375/375 [00:27<00:00, 13.49it/s]


Training Accuracy: 0.8760833144187927%
Validation Accuracy: 0.8705000281333923%

Epoch 8


100%|█████████████████████████████████████████████████████████████| 375/375 [00:29<00:00, 12.84it/s]


Training Accuracy: 0.8817916512489319%
Validation Accuracy: 0.8736666440963745%

Epoch 9


100%|█████████████████████████████████████████████████████████████| 375/375 [00:28<00:00, 13.28it/s]


Training Accuracy: 0.8832708597183228%
Validation Accuracy: 0.8759999871253967%

Epoch 10


100%|█████████████████████████████████████████████████████████████| 375/375 [00:28<00:00, 13.05it/s]


Training Accuracy: 0.8871874809265137%
Validation Accuracy: 0.8793333172798157%

