In [1]:
import chainer
import chainer.links as L
import chainer.functions as F
from chainer import datasets, iterators, initializers
from chainer.dataset import concat_examples
import numpy as np

In [2]:
def reset_iter_state(iterator):
    iterator.epoch = 0
    iterator.current_position = 0
    iterator.is_new_epoch = False
    iterator._pushed_position = None

In [8]:
# Parameters
learning_rate = 0.001
training_epochs = 100
batch_size = 5
display_step = 5

# Network Parameters
n_hiddens = 64
num_classes = 10 # MNIST total classes (0-9 digits)
dropout = 0.75 # Dropout, probability to keep units

In [11]:
# Load data
train, test = chainer.datasets.get_mnist(withlabel=True, ndim=2, scale=1.)
reduced_train = chainer.iterators.SerialIterator(dataset=train, batch_size=5000, repeat=True, shuffle=True).next()
reduced_test = chainer.iterators.SerialIterator(dataset=train, batch_size=1000, repeat=True, shuffle=True).next()
train_iter = chainer.iterators.SerialIterator(dataset=reduced_train, batch_size=batch_size, 
                                              repeat=True, shuffle=True)
test_iter = chainer.iterators.SerialIterator(dataset=reduced_test, batch_size=batch_size, 
                                             repeat=False, shuffle=False)

In [12]:
class RNN(chainer.Chain):
    def __init__(self):
        super(RNN, self).__init__()
        with self.init_scope():
            self.lstm = L.LSTM(in_size=28, out_size=n_hiddens)
            self.fc = L.Linear(in_size=n_hiddens, out_size=num_classes)
    
    def reset_state(self):
        self.lstm.reset_state()
    
    def forward_one_step(self, x):
        self.reset_state()
        for i in range(x.shape[1]):
            self.lstm(x[:, i:i+1])
        h = self.fc(self.lstm.h)
        return h

In [13]:
model = RNN()
optimiser = chainer.optimizers.Adam(alpha=learning_rate)
optimiser.setup(model)

In [14]:
reset_iter_state(train_iter)

while train_iter.epoch < training_epochs:
    train_batch = train_iter.next()
    image_train, target_train = concat_examples(train_batch)
    
    # Calculate the prediction of the network
    prediction_train = model.forward_one_step(image_train)

    # Calculate the loss with softmax_cross_entropy
    loss = F.softmax_cross_entropy(prediction_train, target_train)

    # Calculate the gradients in the network
    model.cleargrads()
    loss.backward()
    
    # Update all the trainable paremters
    optimiser.update()
    # --------------------- until here ---------------------

    # Check the validation accuracy of prediction after every epoch
    if train_iter.is_new_epoch and (train_iter.epoch + 1) % display_step == 0:
        # If this iteration is the final iteration of the current epoch
        with chainer.using_config('train', False):
            test_losses = []
            test_accuracies = []
            while True:
                test_batch = test_iter.next()
                image_test, target_test = concat_examples(test_batch)

                # Forward the test data
                prediction_test = model.forward_one_step(image_test)

                # Calculate the loss
                loss_test = F.softmax_cross_entropy(prediction_test, target_test)
                test_losses.append(loss_test.data)

                # Calculate the accuracy
                accuracy = F.accuracy(prediction_test, target_test)
                test_accuracies.append(accuracy.data)

                if test_iter.is_new_epoch:
                    reset_iter_state(test_iter)
                    break

            print('epoch={:03d} train_loss={:.05f} val_loss={:.05f} val_accuracy={:.05f}'.format(
                train_iter.epoch + 1, float(loss.data), np.mean(test_losses), np.mean(test_accuracies)))

epoch=005 train_loss=0.21106 val_loss=0.25334 val_accuracy=0.92300
epoch=010 train_loss=0.09774 val_loss=0.19085 val_accuracy=0.94700
epoch=015 train_loss=0.00518 val_loss=0.20499 val_accuracy=0.94800
epoch=020 train_loss=0.00082 val_loss=0.21573 val_accuracy=0.94000


KeyboardInterrupt: 