In [1]:
import chainer
import chainer.functions as F
import chainer.links as L
from chainer.dataset import concat_examples
from chainer import initializers
import numpy as np
from chainer.utils import type_check
import matplotlib.pyplot as plt

In [2]:
class LogisticRegression(chainer.Chain):
    def __init__(self, in_size, out_size):
        super(LogisticRegression, self).__init__()
        with self.init_scope():
            self.l1 = L.Linear(in_size, out_size)
            
    def __call__(self, x):
        h = self.l1(x)
        return h

In [3]:
# Params
learning_rate = 0.01
training_epochs = 100
display_step = 1

In [4]:
# Load data
train, test = chainer.datasets.get_mnist(withlabel=True, ndim=1, scale=1.)
reduced_train = chainer.iterators.SerialIterator(dataset=train, batch_size=5000, repeat=True, shuffle=True).next()
reduced_test = chainer.iterators.SerialIterator(dataset=train, batch_size=500, repeat=True, shuffle=True).next()
train_iter = chainer.iterators.SerialIterator(dataset=reduced_train, batch_size=100, repeat=True, shuffle=True)
test_iter = chainer.iterators.SerialIterator(dataset=reduced_test, batch_size=100, repeat=True, shuffle=True)

In [5]:
# Build model & optimiser
model = LogisticRegression(784, 10)
optimiser = chainer.optimizers.SGD(lr=learning_rate)
optimiser.setup(model)

In [6]:
def reset_iter_state(iterator):
    iterator.epoch = 0
    iterator.current_position = 0
    iterator.is_new_epoch = False
    iterator._pushed_position = None

In [7]:
reset_iter_state(train_iter)

while train_iter.epoch < training_epochs:
    train_batch = train_iter.next()
    image_train, target_train = concat_examples(train_batch)
    
    # Calculate the prediction of the network
    prediction_train = model(image_train)

    # Calculate the loss with softmax_cross_entropy
    loss = F.softmax_cross_entropy(prediction_train, target_train)

    # Calculate the gradients in the network
    model.cleargrads()
    loss.backward()

    # Update all the trainable paremters
    optimiser.update()
    # --------------------- until here ---------------------

    # Check the validation accuracy of prediction after every epoch
    if train_iter.is_new_epoch and (train_iter.epoch + 1) % display_step == 0:
        # If this iteration is the final iteration of the current epoch

        test_losses = []
        test_accuracies = []
        while True:
            test_batch = test_iter.next()
            image_test, target_test = concat_examples(test_batch)

            # Forward the test data
            prediction_test = model(image_test)

            # Calculate the loss
            loss_test = F.softmax_cross_entropy(prediction_test, target_test)
            test_losses.append(loss_test.data)

            # Calculate the accuracy
            accuracy = F.accuracy(prediction_test, target_test)
            test_accuracies.append(accuracy.data)

            if test_iter.is_new_epoch:
                reset_iter_state(test_iter)
                break

        print('epoch={:03d} train_loss={:.05f} val_loss={:.05f} val_accuracy={:.05f}'.format(
            train_iter.epoch, float(loss.data), np.mean(test_losses), np.mean(test_accuracies)))

epoch=002 train_loss=1.97720 val_loss=1.90184 val_accuracy=0.51000
epoch=003 train_loss=1.59115 val_loss=1.58910 val_accuracy=0.69200
epoch=004 train_loss=1.37699 val_loss=1.37247 val_accuracy=0.72200
epoch=005 train_loss=1.23446 val_loss=1.21922 val_accuracy=0.75400
epoch=006 train_loss=1.10783 val_loss=1.10697 val_accuracy=0.78800
epoch=007 train_loss=0.96216 val_loss=1.02269 val_accuracy=0.79400
epoch=008 train_loss=0.97025 val_loss=0.95666 val_accuracy=0.80000
epoch=009 train_loss=0.96979 val_loss=0.90328 val_accuracy=0.81400
epoch=010 train_loss=0.77683 val_loss=0.86041 val_accuracy=0.82200
epoch=011 train_loss=0.83561 val_loss=0.82285 val_accuracy=0.82400
epoch=012 train_loss=0.77678 val_loss=0.79236 val_accuracy=0.82200
epoch=013 train_loss=0.71521 val_loss=0.76631 val_accuracy=0.82000
epoch=014 train_loss=0.73232 val_loss=0.74301 val_accuracy=0.82600
epoch=015 train_loss=0.72577 val_loss=0.72308 val_accuracy=0.83000
epoch=016 train_loss=0.67781 val_loss=0.70545 val_accuracy=0.8