In [1]:
import chainer
import chainer.links as L
import chainer.functions as F
from chainer.dataset import concat_examples
import numpy as np

In [2]:
# Parameters
learning_rate = 0.01
training_epochs = 100
batch_size = 128
display_step = 1

# Network Parameters
n_hidden_1 = 256 # 1st layer number of neurons
n_hidden_2 = 256 # 2nd layer number of neurons
num_input = 784 # MNIST data input (img shape: 28*28)
num_classes = 10 # MNIST total classes (0-9 digits)

In [3]:
# Load data
train, test = chainer.datasets.get_mnist(withlabel=True, ndim=1, scale=1.)
reduced_train = chainer.iterators.SerialIterator(dataset=train, batch_size=5000, repeat=True, shuffle=True).next()
reduced_test = chainer.iterators.SerialIterator(dataset=train, batch_size=1000, repeat=True, shuffle=True).next()
train_iter = chainer.iterators.SerialIterator(dataset=reduced_train, batch_size=batch_size, repeat=True, shuffle=True)
test_iter = chainer.iterators.SerialIterator(dataset=reduced_test, batch_size=batch_size, repeat=True, shuffle=True)

In [4]:
class NerualNetwork(chainer.Chain):
    def __init__(self):
        super(NerualNetwork, self).__init__()
        with self.init_scope():
            self.fc1 = L.Linear(num_input, n_hidden_1, nobias=False,
                               initial_bias=0)
            self.fc2 = L.Linear(n_hidden_1, n_hidden_2)
            self.fc3 = L.Linear(n_hidden_2, num_classes)
            
    def __call__(self, x):
        z1 = F.relu(self.fc1(x))
        z2 = F.relu(self.fc2(z1))
        h = self.fc3(z2)
        return h

In [5]:
model = NerualNetwork()
optimiser = chainer.optimizers.SGD(lr=learning_rate)
optimiser.setup(model)

In [6]:
def reset_iter_state(iterator):
    iterator.epoch = 0
    iterator.current_position = 0
    iterator.is_new_epoch = False
    iterator._pushed_position = None

In [7]:
reset_iter_state(train_iter)

while train_iter.epoch < training_epochs:
    train_batch = train_iter.next()
    image_train, target_train = concat_examples(train_batch)
    
    # Calculate the prediction of the network
    prediction_train = model(image_train)

    # Calculate the loss with softmax_cross_entropy
    loss = F.softmax_cross_entropy(prediction_train, target_train)

    # Calculate the gradients in the network
    model.cleargrads()
    loss.backward()
    
    # Update all the trainable paremters
    optimiser.update()
    # --------------------- until here ---------------------

    # Check the validation accuracy of prediction after every epoch
    if train_iter.is_new_epoch and (train_iter.epoch + 1) % display_step == 0:
        # If this iteration is the final iteration of the current epoch

        test_losses = []
        test_accuracies = []
        while True:
            test_batch = test_iter.next()
            image_test, target_test = concat_examples(test_batch)

            # Forward the test data
            prediction_test = model(image_test)

            # Calculate the loss
            loss_test = F.softmax_cross_entropy(prediction_test, target_test)
            test_losses.append(loss_test.data)

            # Calculate the accuracy
            accuracy = F.accuracy(prediction_test, target_test)
            test_accuracies.append(accuracy.data)

            if test_iter.is_new_epoch:
                reset_iter_state(test_iter)
                break

        print('epoch={:03d} train_loss={:.05f} val_loss={:.05f} val_accuracy={:.05f}'.format(
            train_iter.epoch + 1, float(loss.data), np.mean(test_losses), np.mean(test_accuracies)))

epoch=002 train_loss=2.22040 val_loss=2.22141 val_accuracy=0.25879
epoch=003 train_loss=2.12945 val_loss=2.10679 val_accuracy=0.46094
epoch=004 train_loss=1.91256 val_loss=1.96766 val_accuracy=0.59375
epoch=005 train_loss=1.78751 val_loss=1.80304 val_accuracy=0.66602
epoch=006 train_loss=1.61464 val_loss=1.61648 val_accuracy=0.70703
epoch=007 train_loss=1.43525 val_loss=1.42558 val_accuracy=0.73730
epoch=008 train_loss=1.24614 val_loss=1.24967 val_accuracy=0.75781
epoch=009 train_loss=1.22788 val_loss=1.09275 val_accuracy=0.78418
epoch=010 train_loss=1.02042 val_loss=0.96866 val_accuracy=0.80176
epoch=011 train_loss=0.84979 val_loss=0.87127 val_accuracy=0.81250
epoch=012 train_loss=0.71199 val_loss=0.79256 val_accuracy=0.81738
epoch=013 train_loss=0.63298 val_loss=0.73648 val_accuracy=0.83203
epoch=014 train_loss=0.68419 val_loss=0.68782 val_accuracy=0.84082
epoch=015 train_loss=0.64720 val_loss=0.64379 val_accuracy=0.84863
epoch=016 train_loss=0.63590 val_loss=0.61195 val_accuracy=0.8