In [1]:
import chainer
import chainer.functions as F
import chainer.links as L

In [2]:
# CIFAR-10 as a (training split, test split) pair, kept for later cells.
cifar10_splits = chainer.datasets.get_cifar10()
train, test = cifar10_splits

In [15]:
class CNN(chainer.Chain):
    """Six 3x3 convolutions followed by two fully connected layers.

    Every convolution uses a 3x3 kernel without padding; conv2, conv4 and
    conv6 use stride 2 to downsample. For a 32x32 input (e.g. CIFAR-10) the
    spatial size shrinks 32 -> 30 -> 14 -> 12 -> 5 -> 3 -> 1, so ``fc7``
    receives 128 features. Noticeably smaller inputs shrink below the kernel
    size part-way through the stack and raise an error.

    Args:
        n_out: Number of output classes (width of the final logits).
        n_hidden: Width of the hidden fully connected layer ``fc7``.
            Defaults to 100, the value previously hard-coded here.
    """

    def __init__(self, n_out, n_hidden=100):
        super().__init__()
        with self.init_scope():
            # in_channels=None: inferred from the first batch (e.g. 3 for RGB).
            self.conv1 = L.Convolution2D(None, 16, ksize=3, stride=1)
            self.conv2 = L.Convolution2D(16, 32, ksize=3, stride=2)
            self.conv3 = L.Convolution2D(32, 32, ksize=3, stride=1)
            self.conv4 = L.Convolution2D(32, 64, ksize=3, stride=2)
            self.conv5 = L.Convolution2D(64, 64, ksize=3, stride=1)
            self.conv6 = L.Convolution2D(64, 128, ksize=3, stride=2)
            # in_size=None: the flattened conv6 output size is inferred lazily.
            self.fc7 = L.Linear(None, n_hidden)
            self.fc8 = L.Linear(n_hidden, n_out)

    def __call__(self, x):
        """Return unnormalised class scores of shape (batch, n_out).

        ``x`` is presumably a float32 image batch (batch, channels, H, W) --
        confirm against the dataset. No softmax is applied here; the loss
        wrapper (e.g. ``L.Classifier``) is expected to handle it.
        """
        h = x
        # Same layer order as before: each conv is followed by a ReLU.
        for conv in (self.conv1, self.conv2, self.conv3,
                     self.conv4, self.conv5, self.conv6):
            h = F.relu(conv(h))
        h = F.relu(self.fc7(h))
        return self.fc8(h)

In [16]:
from chainer import optimizers, iterators, training
from chainer.training import extensions
def train_model(model, batch_size=64, epoch=20, train=None, test=None,
                device=None, out='result'):
    """Train ``model`` on CIFAR-10 with Adam using Chainer's Trainer loop.

    The network is wrapped in ``L.Classifier`` (which reports ``main/loss``
    and ``main/accuracy``), evaluated on the test split after every epoch,
    and a trainer snapshot, a log, loss/accuracy plots and a computational
    graph dump are written to ``out``.

    Args:
        model: A ``chainer.Link`` mapping an image batch to class scores.
        batch_size: Mini-batch size for both training and evaluation.
        epoch: Number of training epochs.
        train: Optional pre-loaded training dataset.
        test: Optional pre-loaded test dataset. When either dataset is
            ``None``, CIFAR-10 is loaded via ``chainer.datasets.get_cifar10()``
            to fill the missing split(s), so callers that already hold the
            splits avoid loading them a second time.
        device: Forwarded to ``StandardUpdater`` and ``Evaluator`` (e.g. a
            GPU id such as ``0``). ``None`` keeps Chainer's default (CPU).
        out: Output directory for the trainer's logs, plots and snapshots.

    Returns:
        The ``L.Classifier`` wrapping ``model``. The parameters of ``model``
        are updated in place, so the caller's reference is trained as well.
    """
    # 1. Wrap the network so the trainer gets a loss and an accuracy metric.
    classifier_model = L.Classifier(model)

    # 2. Setup an optimizer
    optimizer = optimizers.Adam()
    optimizer.setup(classifier_model)

    # 3. Load CIFAR-10 only for the split(s) the caller did not supply.
    if train is None or test is None:
        cifar_train, cifar_test = chainer.datasets.get_cifar10()
        if train is None:
            train = cifar_train
        if test is None:
            test = cifar_test

    # 4. Setup iterators; the test iterator makes one ordered pass per evaluation.
    train_iter = iterators.SerialIterator(train, batch_size)
    test_iter = iterators.SerialIterator(test, batch_size,
                                         repeat=False, shuffle=False)

    # 5. Setup an updater (``device`` selects where batches/model live).
    updater = training.StandardUpdater(train_iter, optimizer, device=device)

    # 6. Setup a trainer (and extensions)
    trainer = training.Trainer(updater, (epoch, 'epoch'), out=out)

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(extensions.Evaluator(test_iter, classifier_model,
                                        device=device))

    trainer.extend(extensions.dump_graph('main/loss'))
    trainer.extend(extensions.snapshot(), trigger=(1, 'epoch'))
    trainer.extend(extensions.LogReport())
    trainer.extend(extensions.PrintReport(
        ['epoch', 'main/loss', 'validation/main/loss',
         'main/accuracy', 'validation/main/accuracy', 'elapsed_time']))
    trainer.extend(extensions.PlotReport(
        ['main/loss', 'validation/main/loss'],
        x_key='epoch', file_name='loss.png'))
    trainer.extend(extensions.PlotReport(
        ['main/accuracy', 'validation/main/accuracy'],
        x_key='epoch',
        file_name='accuracy.png'))

    trainer.extend(extensions.ProgressBar())

    # Run the training
    trainer.run()
    return classifier_model

In [17]:
# Bind the network to a name so its trained weights remain reachable afterwards.
cnn = CNN(n_out=10)  # CIFAR-10 has 10 classes
train_model(cnn)

epoch       main/loss   validation/main/loss  main/accuracy  validation/main/accuracy  elapsed_time
[J

KeyboardInterrupt: 