In [8]:
from __future__ import print_function

from argparse import Namespace

import chainer
import chainer.functions as F
import chainer.links as L
from chainer import training
from chainer.training import extensions

In [9]:
class MLP(chainer.Chain):

    def __init__(self, n_units, n_out):
        super(MLP, self).__init__()
        with self.init_scope():
            # the size of the inputs to each layer will be inferred
            self.l1 = L.Linear(None, n_units)  # n_in -> n_units
            self.l2 = L.Linear(None, n_units)  # n_units -> n_units
            self.l3 = L.Linear(None, n_out)  # n_units -> n_out

    def __call__(self, x):
        h1 = F.relu(self.l1(x))
        h2 = F.relu(self.l2(h1))
        return self.l3(h2)

In [16]:
def main(args):
    model = L.Classifier(MLP(args.unit, 10))
    if args.gpu >= 0:
        # Make a specified GPU current
        chainer.cuda.get_device_from_id(args.gpu).use()
        model.to_gpu()  # Copy the model to the GPU

    # Setup an optimizer
    optimizer = chainer.optimizers.Adam()
    optimizer.setup(model)

    # Load the MNIST dataset
    train, test = chainer.datasets.get_mnist()

    train_iter = chainer.iterators.SerialIterator(train, args.batchsize)
    test_iter = chainer.iterators.SerialIterator(test, args.batchsize,
                                                 repeat=False, shuffle=False)

    # Set up a trainer
    updater = training.StandardUpdater(train_iter, optimizer, device=args.gpu)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(extensions.Evaluator(test_iter, model, device=args.gpu))

    # Take a snapshot for each epoch
    trainer.extend(extensions.snapshot(), trigger=(1, 'epoch'))

    # Write a log of evaluation statistics for each epoch
    trainer.extend(extensions.LogReport())

    trainer.extend(extensions.PrintReport(
        ['epoch', 'main/loss', 'validation/main/loss',
         'main/accuracy', 'validation/main/accuracy', 'elapsed_time']))

    # Print a progress bar to stdout
    trainer.extend(extensions.ProgressBar())
    
    if args.resume:
        # Resume from a snapshot
        chainer.serializers.load_npz(args.resume, trainer)
        
    if args.freeze:
        # freeze all layers but the last fc layer
        model.predictor.disable_update()
        model.predictor.l3.enable_update()
        
    print('freezing previous layers: ',
          not model.predictor.l1.update_enabled)
    
    # Run the training
    trainer.run()

In [17]:
args = {
        'batchsize': 100,
        'epoch': 5,
        'gpu': 0,
        'out': 'result',
        'resume': False,
        'unit': 1000,
        'freeze': False,
       }
args = Namespace(**args)

## Sometimes we want to first freezing the earlier layers

In [23]:
args.freeze = True
args.resume = None
main(args)

freezing previous layers:  True
epoch       main/loss   validation/main/loss  main/accuracy  validation/main/accuracy  elapsed_time
[J     total [#.................................................]  3.33%
this epoch [########..........................................] 16.67%
       100 iter, 0 epoch / 5 epochs
       inf iters/sec. Estimated time to finish: 0:00:00.
[4A[J     total [###...............................................]  6.67%
this epoch [################..................................] 33.33%
       200 iter, 0 epoch / 5 epochs
    325.07 iters/sec. Estimated time to finish: 0:00:08.613542.
[4A[J     total [#####.............................................] 10.00%
this epoch [#########################.........................] 50.00%
       300 iter, 0 epoch / 5 epochs
    324.08 iters/sec. Estimated time to finish: 0:00:08.331194.
[4A[J     total [######............................................] 13.33%
this epoch [#################################.........

## When we want to resume training for all layers, from a previous snapshot that freezes layers, such error occurs

In [21]:
args.freeze = False
args.resume = 'result/snapshot_iter_3000'
main(args)

KeyError: 'updater/optimizer:main/predictor/l2/W/m is not a file in the archive'