In [34]:
import argparse
import copy
import numpy as np

import chainer
from chainer.datasets import TransformDataset
from chainer.optimizer import WeightDecay
from chainer import serializers
from chainer import training
from chainer.training import extensions
from chainer.training import triggers

from chainercv.datasets import voc_bbox_label_names
from chainercv.datasets import VOCBboxDataset
from chainercv.extensions import DetectionVOCEvaluator
from chainercv.links.model.ssd import GradientScaling
from chainercv.links.model.ssd import multibox_loss
from chainercv.links import SSD300
from chainercv.links import SSD512
from chainercv import transforms

from chainercv.links.model.ssd import random_crop_with_bbox_constraints
from chainercv.links.model.ssd import random_distort
from chainercv.links.model.ssd import resize_with_random_interpolation

In [35]:
class ConcatenatedDataset(chainer.dataset.DatasetMixin):
    """Present several datasets as one, laid out back to back.

    Index ``i`` addresses the first dataset while ``i`` is smaller than its
    length, then the second dataset (shifted by the first one's length),
    and so on.
    """

    def __init__(self, *datasets):
        # Kept as the tuple of positional arguments, in the given order.
        self._datasets = datasets

    def __len__(self):
        """Return the total number of examples over all datasets."""
        return sum(map(len, self._datasets))

    def get_example(self, i):
        """Return the example at global position ``i``.

        Raises:
            IndexError: if ``i`` is negative or not smaller than ``len(self)``.
        """
        if i < 0:
            raise IndexError
        offset = i
        for dataset in self._datasets:
            size = len(dataset)
            if offset < size:
                return dataset[offset]
            # Not in this dataset: make the index relative to the next one.
            offset -= size
        raise IndexError

In [36]:
class MultiboxTrainChain(chainer.Chain):
    """Training wrapper that turns a detector's outputs into a scalar loss.

    Args:
        model: Link called as ``model(imgs)`` and expected to return a
            ``(mb_locs, mb_confs)`` pair.
        alpha: Multiplier applied to the localization loss before it is
            added to the confidence loss.
        k: Forwarded to ``multibox_loss`` as its fifth positional argument.
    """

    def __init__(self, model, alpha=1, k=3):
        super(MultiboxTrainChain, self).__init__()
        # Register the wrapped model as a child link of this chain.
        with self.init_scope():
            self.model = model
        self.alpha = alpha
        self.k = k

    def __call__(self, imgs, gt_mb_locs, gt_mb_labels):
        """Run the model, compute the multibox loss, report it, return it.

        Reports ``loss``, ``loss/loc`` and ``loss/conf`` with this chain as
        the observer, and returns the combined loss.
        """
        pred_locs, pred_confs = self.model(imgs)
        loc_loss, conf_loss = multibox_loss(
            pred_locs, pred_confs, gt_mb_locs, gt_mb_labels, self.k)
        loss = loc_loss * self.alpha + conf_loss

        observation = {
            'loss': loss,
            'loss/loc': loc_loss,
            'loss/conf': conf_loss,
        }
        chainer.reporter.report(observation, self)

        return loss

In [127]:
class Transform(object):
    """Training-time augmentation followed by SSD target encoding.

    Args:
        coder: Object providing ``to_cpu()`` and ``encode(bbox, label)``.
            A ``copy.copy`` of it is stored and moved to CPU, so the object
            passed in is not itself sent to CPU by this class.
        size: Side length of the square image fed to the network.
        mean: Value used to fill expanded image regions and subtracted from
            the image at the end.
    """

    def __init__(self, coder, size, mean):
        # Copy first, then send the copy to CPU.
        cpu_coder = copy.copy(coder)
        cpu_coder.to_cpu()
        self.coder = cpu_coder

        self.size = size
        self.mean = mean

    def __call__(self, in_data):
        """Augment one ``(img, bbox, label)`` example and encode its targets.

        The five augmentation steps run in this order:
          1. Color augmentation
          2. Random expansion
          3. Random cropping
          4. Resizing with random interpolation
          5. Random horizontal flipping

        Returns:
            ``(img, mb_loc, mb_label)`` where ``mb_loc`` and ``mb_label``
            come from ``self.coder.encode``.
        """
        img, bbox, label = in_data
        out_size = (self.size, self.size)

        # 1. Color augmentation
        img = random_distort(img)

        # 2. Random expansion (skipped when the random draw is 0)
        if np.random.randint(2):
            img, expand_param = transforms.random_expand(
                img, fill=self.mean, return_param=True)
            bbox = transforms.translate_bbox(
                bbox,
                y_offset=expand_param['y_offset'],
                x_offset=expand_param['x_offset'])

        # 3. Random cropping
        img, crop_param = random_crop_with_bbox_constraints(
            img, bbox, return_param=True)
        bbox, kept_param = transforms.crop_bbox(
            bbox,
            y_slice=crop_param['y_slice'],
            x_slice=crop_param['x_slice'],
            allow_outside_center=False,
            return_param=True)
        # Keep only the labels of the boxes that crop_bbox retained.
        label = label[kept_param['index']]

        # 4. Resizing with random interpolation
        # img is unpacked as a 3-D array; the last two axes are used as the
        # pre-resize height and width.
        _, height, width = img.shape
        img = resize_with_random_interpolation(img, out_size)
        bbox = transforms.resize_bbox(bbox, (height, width), out_size)

        # 5. Random horizontal flipping
        img, flip_param = transforms.random_flip(
            img, x_random=True, return_param=True)
        bbox = transforms.flip_bbox(
            bbox, out_size, x_flip=flip_param['x_flip'])

        # Preparation for SSD network
        img -= self.mean
        mb_loc, mb_label = self.coder.encode(bbox, label)

        return img, mb_loc, mb_label


In [128]:
# Build an SSD300 detector with one foreground class per VOC label name,
# requesting the 'imagenet' pretrained weights.
model = SSD300(
    n_fg_class=len(voc_bbox_label_names),
    pretrained_model='imagenet',
)

In [129]:
# Switch the detector to its 'evaluate' preset.
# NOTE(review): presumably this sets the detection thresholds used when the
# evaluator calls the model for prediction — confirm against ChainerCV docs.
model.use_preset('evaluate')
# Wrap the detector so that calling the chain yields the multibox training
# loss (alpha and k keep their defaults).
train_chain = MultiboxTrainChain(model)


In [130]:
# Training configuration: mini-batch size shared by the train and test
# iterators, and the directory handed to the trainer as its output location.
BATCH_SIZE = 3
OUTPUT = 'result'

# Dataset location. Download VOC2007 from
# http://host.robots.ox.ac.uk/pascal/VOC/voc2007/#data and extract the
# contents of the VOC2007 folder into a 'train' and a 'test' directory.
# The '{}' placeholder is filled with that directory name via .format().
data_path = 'data/ssd/voc2007/{}'

In [131]:
# Training set: VOC2007 'trainval' read from the 'train' directory, with the
# augmentation/encoding Transform applied to each example on access.
train = TransformDataset(
    dataset=VOCBboxDataset(
        data_dir=data_path.format('train'),
        split='trainval',
        year='2007'),
    transform=Transform(
        coder=model.coder, size=model.insize, mean=model.mean))
# Batches are produced by a multiprocess iterator.
train_iter = chainer.iterators.MultiprocessIterator(
    dataset=train, batch_size=BATCH_SIZE)

In [132]:
# Evaluation set: VOC2007 'test' read from the 'test' directory. Objects
# flagged as difficult are kept and the flag is returned with each example.
test = VOCBboxDataset(
    data_dir=data_path.format('test'),
    split='test',
    year='2007',
    use_difficult=True,
    return_difficult=True)
# Single pass over the test set, in dataset order.
test_iter = chainer.iterators.SerialIterator(
    dataset=test, batch_size=BATCH_SIZE, repeat=False, shuffle=False)

In [133]:
# Momentum SGD created with its default hyperparameters.
optimizer = chainer.optimizers.MomentumSGD()
# Attach the optimizer to the training chain. This must happen before
# hooks are added to the parameters' update rules.
optimizer.setup(train_chain)

In [134]:
# Per-parameter update hooks:
#   - parameters named 'b' (biases) get GradientScaling(2) and no weight decay;
#   - every other parameter gets WeightDecay(0.0005).
# NOTE(review): this appears to mirror the original Caffe SSD solver settings
# (doubled learning rate and no decay on biases) — confirm against the
# ChainerCV SSD training example.
for param in train_chain.params():
    param.update_rule.add_hook(
        GradientScaling(2) if param.name == 'b' else WeightDecay(0.0005))

In [135]:
# Standard single-optimizer update loop over the training iterator.
updater = training.StandardUpdater(
    iterator=train_iter, optimizer=optimizer)
# Train for 120000 iterations; trainer outputs go under OUTPUT.
trainer = training.Trainer(
    updater, stop_trigger=(120000, 'iteration'), out=OUTPUT)
# Learning-rate schedule: 'lr' is initialised to 1e-3 and multiplied by 0.1
# at iterations 80000 and 100000.
lr_drop_trigger = triggers.ManualScheduleTrigger([80000, 100000], 'iteration')
trainer.extend(
    extensions.ExponentialShift('lr', 0.1, init=1e-3),
    trigger=lr_drop_trigger)

In [136]:
# Evaluate the detector on the test iterator every 10000 iterations,
# with the VOC2007 metric enabled.
trainer.extend(
    DetectionVOCEvaluator(
        iterator=test_iter,
        target=model,
        use_07_metric=True,
        label_names=voc_bbox_label_names),
    trigger=(10000, 'iteration'))

In [137]:
# Reporting cadence shared by the logging-related extensions below.
log_interval = (10, 'iteration')

# Aggregate reported values into a log entry and record the current
# learning rate, both every `log_interval`.
trainer.extend(extensions.LogReport(trigger=log_interval))
trainer.extend(extensions.observe_lr(), trigger=log_interval)

# Print the selected log columns on the same schedule as LogReport.
# The trigger had been left commented out after the call, so PrintReport
# fell back to the default per-iteration trigger; it is restored here so
# that all logging extensions share one interval.
trainer.extend(
    extensions.PrintReport(
        ['epoch', 'iteration', 'lr',
         'main/loss', 'main/loss/loc', 'main/loss/conf',
         'validation/main/map']),
    trigger=log_interval)

# Refresh the progress bar every 10 iterations.
trainer.extend(extensions.ProgressBar(update_interval=10))

In [138]:
# Snapshot of the whole trainer every 10000 iterations.
trainer.extend(
    extensions.snapshot(),
    trigger=(10000, 'iteration'))
# Snapshot of the detector alone at iteration 120000 (the stop iteration);
# the file name embeds the updater's iteration count.
trainer.extend(
    extensions.snapshot_object(
        target=model, filename='model_iter_{.updater.iteration}'),
    trigger=(120000, 'iteration'))

In [None]:
# Start the training loop; runs until the trainer's stop trigger fires
# (120000 iterations as configured above).
trainer.run()