# Training an SSD module in MXNet

In [1]:
%matplotlib inline
import d2lzh as d2l
import mxnet as mx
from mxnet import autograd, contrib, gluon, image, init, nd, symbol
from mxnet.gluon import loss as gloss, nn
import time

In [2]:
def cls_predictor(num_anchors, num_classes):
    """Build the class-prediction layer for one feature map.

    The layer is a 3x3 convolution with padding 1 whose output-channel count
    is ``num_anchors * (num_classes + 1)``.

    :param num_anchors: factor applied to the per-anchor channel count.
    :param num_classes: number of classes; one extra slot is added to it
        (presumably for the background class -- confirm).
    :return: an uninitialized ``nn.Conv2D`` layer.
    """
    out_channels = num_anchors * (num_classes + 1)
    return nn.Conv2D(channels=out_channels, kernel_size=3, padding=1)

def bbox_predictor(num_anchors):
    """Build the bounding-box prediction layer for one feature map.

    The layer is a 3x3 convolution with padding 1 that emits four output
    channels per anchor, i.e. ``num_anchors * 4`` channels in total.

    :param num_anchors: factor applied to the four per-anchor channels.
    :return: an uninitialized ``nn.Conv2D`` layer.
    """
    out_channels = 4 * num_anchors
    return nn.Conv2D(channels=out_channels, kernel_size=3, padding=1)

def forward(x, block):
    """Initialize ``block`` with its default settings and apply it to ``x``.

    :param x: input passed to the block.
    :param block: object exposing ``initialize()`` and being callable.
    :return: whatever ``block(x)`` returns.
    """
    block.initialize()
    output = block(x)
    return output

# Apply class predictors to two all-zero inputs that differ in channel count
# and spatial size; the cell displays the two output shapes.
Y1 = forward(nd.zeros(shape=(2, 8, 20, 20)),
             cls_predictor(num_anchors=5, num_classes=10))
Y2 = forward(nd.zeros(shape=(2, 16, 10, 10)),
             cls_predictor(num_anchors=3, num_classes=10))
(Y1.shape, Y2.shape)

((2, 55, 20, 20), (2, 33, 10, 10))

In [3]:
def flatten_pred(pred):
    """Move axis 1 of a 4-D prediction to the last position, then flatten.

    :param pred: 4-D array-like exposing ``transpose`` and ``flatten``.
    :return: result of ``flatten()`` on the axis-permuted prediction.
    """
    channels_last = pred.transpose((0, 2, 3, 1))
    return channels_last.flatten()

def concat_preds(preds, F=None):
    """Flatten every prediction and concatenate the results along axis 1.

    The original implementation always called ``symbol.concat``, which only
    works when the inputs are Symbols (i.e. the network is hybridized). The
    API is now chosen to match the inputs so the function also works on
    NDArrays in imperative mode.

    :param preds: non-empty sequence of predictions accepted by
        ``flatten_pred``.
    :param F: optional API module (``nd`` or ``symbol``) to use for the
        concatenation; inferred from the inputs when omitted.
    :return: the concatenated prediction.
    :raises ValueError: if ``preds`` is empty.
    """
    flat = [flatten_pred(p) for p in preds]
    if not flat:
        raise ValueError('concat_preds() needs at least one prediction')
    if F is None:
        # Symbols must be combined with the symbolic API, NDArrays with the
        # imperative one.
        F = symbol if isinstance(flat[0], symbol.Symbol) else nd
    return F.concat(*flat, dim=1)

In [None]:
def down_sample_blk(num_channels):
    """Build a block of two conv/batch-norm/ReLU groups plus a max pool.

    Each group is a 3x3 convolution (padding 1) with ``num_channels`` output
    channels, a batch normalization over ``num_channels`` channels and a ReLU
    activation. A ``MaxPool2D(2)`` layer is appended at the end.

    :param num_channels: output channels of both convolutions.
    :return: an ``nn.HybridSequential`` holding the seven layers.
    """
    blk = nn.HybridSequential()
    for _group in range(2):
        blk.add(nn.Conv2D(num_channels, kernel_size=3, padding=1))
        blk.add(nn.BatchNorm(in_channels=num_channels))
        blk.add(nn.Activation('relu'))
    blk.add(nn.MaxPool2D(2))
    return blk

In [4]:
def base_net():
    """Build the base feature extractor.

    Stacks three down-sampling blocks with 16, 32 and 64 channels, in that
    order, inside an ``nn.HybridSequential``.

    :return: the assembled ``nn.HybridSequential``.
    """
    net = nn.HybridSequential()
    net.add(*[down_sample_blk(channels) for channels in (16, 32, 64)])
    return net

# Shape check: run an all-zero batch of two 3x256x256 inputs through a fresh
# base network and display the output shape.
forward(nd.zeros(shape=(2, 3, 256, 256)), base_net()).shape

(2, 64, 32, 32)

In [17]:
def get_blk(i):
    """Return the feature block for stage ``i`` of the detector.

    Stage 0 is the base network, stage 4 is a global max pooling layer and
    every other stage is a 128-channel down-sampling block.

    :param i: stage index.
    :return: a newly constructed block for that stage.
    """
    if i == 0:
        return base_net()
    if i == 4:
        return nn.GlobalMaxPool2D()
    return down_sample_blk(128)


def blk_forward(X, blk, size, ratio, cls_predictor, bbox_predictor):
    """Run one detector stage: features, anchors and both predictions.

    The original implementation always called
    ``contrib.symbol.MultiBoxPrior``, which only accepts Symbols and
    therefore failed when the network was run imperatively on NDArrays. The
    anchor operator is now taken from the namespace matching the feature map.

    :param X: input to the stage.
    :param blk: feature block applied to ``X``.
    :param size: anchor sizes forwarded to ``MultiBoxPrior`` as ``sizes``.
    :param ratio: anchor ratios forwarded to ``MultiBoxPrior`` as ``ratios``.
    :param cls_predictor: layer applied to the feature map for class scores.
    :param bbox_predictor: layer applied to the feature map for box offsets.
    :return: tuple ``(Y, anchors, cls_preds, bbox_preds)`` where ``Y`` is the
        feature map produced by ``blk``.
    """
    Y = blk(X)
    # Symbolic and imperative operators live in separate namespaces; pick the
    # one that matches the type of the feature map.
    multibox_api = contrib.symbol if isinstance(Y, symbol.Symbol) else contrib.nd
    anchors = multibox_api.MultiBoxPrior(Y, sizes=size, ratios=ratio)
    cls_preds = cls_predictor(Y)
    bbox_preds = bbox_predictor(Y)
    return (Y, anchors, cls_preds, bbox_preds)


# Anchor sizes, one pair per detector stage.
sizes = [
    [0.2, 0.272],
    [0.37, 0.447],
    [0.54, 0.619],
    [0.71, 0.79],
    [0.88, 0.961],
]
# Every stage shares the same (single) list of aspect ratios.
ratios = [[1, 2, 0.5]] * len(sizes)
# Anchor count handed to the predictors; MultiBoxPrior is documented to
# produce len(sizes) + len(ratios) - 1 boxes per position.
num_anchors = len(sizes[0]) + len(ratios[0]) - 1


class TinySSD(nn.HybridBlock):
    """Five-stage single-shot detector built from the helpers in this file.

    Stage ``i`` owns three attributes: the feature block ``blk_i``, the class
    predictor ``cls_i`` and the box predictor ``bbox_i``.

    :param num_classes: number of classes passed to every class predictor.
    :param kwargs: forwarded to ``nn.HybridBlock``.
    """

    def __init__(self, num_classes, **kwargs):
        super().__init__(**kwargs)
        self.num_classes = num_classes
        for stage in range(5):
            # Same effect as writing ``self.blk_<stage> = ...`` by hand.
            setattr(self, 'blk_%d' % stage, get_blk(stage))
            setattr(self, 'cls_%d' % stage,
                    cls_predictor(num_anchors, num_classes))
            setattr(self, 'bbox_%d' % stage, bbox_predictor(num_anchors))

    def hybrid_forward(self, F, X):
        """Run all five stages and gather anchors and predictions.

        ``F`` is the API module supplied by Gluon (NDArray or Symbol).

        :return: tuple of the anchors concatenated along axis 1, the class
            predictions reshaped with ``(0, -1, num_classes + 1)`` and the
            concatenated box predictions.
        """
        print('F: ', F)
        print('x: ', X)
        anchors, cls_preds, bbox_preds = [], [], []
        features = X
        for stage in range(5):
            # Equivalent to using self.blk_<stage>, self.cls_<stage>, ...
            features, stage_anchors, stage_cls, stage_bbox = blk_forward(
                features, getattr(self, 'blk_%d' % stage),
                sizes[stage], ratios[stage],
                getattr(self, 'cls_%d' % stage),
                getattr(self, 'bbox_%d' % stage))
            anchors.append(stage_anchors)
            cls_preds.append(stage_cls)
            bbox_preds.append(stage_bbox)

        print('Anchor: ', type(anchors[0]))
        print('cls_pred: ', type(cls_preds[0]))
        print('bbox_pred: ', type(bbox_preds[0]))
        all_anchors = F.concat(*anchors, dim=1)
        cls_out = concat_preds(cls_preds).reshape(
            (0, -1, self.num_classes + 1))
        bbox_out = concat_preds(bbox_preds)
        return (all_anchors, cls_out, bbox_out)

In [18]:
# Smoke test: build a single-class detector with default initialization and
# push one all-zero batch of 32 images (3x256x256) through it.
net = TinySSD(num_classes=1)
net.initialize()
# With hybridize() enabled, hybrid_forward receives Symbol inputs (see the
# types printed from inside it), while the values returned here are NDArrays.
net.hybridize()
X = nd.zeros((32, 3, 256, 256))
anchors, cls_preds, bbox_preds = net(X)

# Report the shape (and, for the first two, the type) of each output.
print('output anchors:', anchors.shape, type(anchors))
print('output class preds:', cls_preds.shape, type(cls_preds))
print('output bbox preds:', bbox_preds.shape)

F:  <module 'mxnet.symbol' from '/anaconda3/envs/gluon/lib/python3.7/site-packages/mxnet/symbol/__init__.py'>
x:  <Symbol data>
Anchor:  <class 'mxnet.symbol.symbol.Symbol'>
cls_pred:  <class 'mxnet.symbol.symbol.Symbol'>
bbox_pred:  <class 'mxnet.symbol.symbol.Symbol'>
output anchors: (1, 5444, 4) <class 'mxnet.ndarray.ndarray.NDArray'>
output class preds: (32, 5444, 2) <class 'mxnet.ndarray.ndarray.NDArray'>
output bbox preds: (32, 21776)


In [19]:
# Mini-batch size; also passed to trainer.step() in the training loop.
batch_size = 32
# Only the first returned iterator is used; the second value is discarded
# (presumably a validation iterator -- confirm against d2l).
train_iter, _ = d2l.load_data_pikachu(batch_size)

In [20]:
# Create a fresh single-class detector on the device chosen by d2l.try_gpu()
# (presumably a GPU when one is available -- confirm against d2l).
ctx, net = d2l.try_gpu(), TinySSD(num_classes=1)
# Xavier initialization, with the parameters placed on ctx.
net.initialize(init=init.Xavier(), ctx=ctx)
net.hybridize()
# Plain SGD over all network parameters with a learning rate of 0.02.
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.02})
# print(type(net), net)

In [21]:
# Loss objects used by calc_loss: softmax cross-entropy for the class
# predictions and L1 loss for the bounding-box predictions.
cls_loss = gloss.SoftmaxCrossEntropyLoss()
bbox_loss = gloss.L1Loss()


def calc_loss(cls_preds, cls_labels, bbox_preds, bbox_labels, bbox_masks):
    """Return the sum of the classification loss and the box loss.

    Both the box predictions and the box labels are multiplied by
    ``bbox_masks`` before they are passed to ``bbox_loss``.

    :param cls_preds: class predictions given to ``cls_loss``.
    :param cls_labels: class labels given to ``cls_loss``.
    :param bbox_preds: box predictions.
    :param bbox_labels: box labels.
    :param bbox_masks: multiplicative mask applied to predictions and labels.
    :return: ``cls_loss(...) + bbox_loss(...)``.
    """
    class_term = cls_loss(cls_preds, cls_labels)
    masked_preds = bbox_preds * bbox_masks
    masked_labels = bbox_labels * bbox_masks
    box_term = bbox_loss(masked_preds, masked_labels)
    return class_term + box_term


def cls_eval(cls_preds, cls_labels):
    """Count the positions where the arg-max class equals the label.

    :param cls_preds: class scores; the arg-max is taken over the last axis.
    :param cls_labels: labels compared element-wise with the arg-max.
    :return: the number of matches as a Python scalar.
    """
    predicted = cls_preds.argmax(axis=-1)
    matches = predicted == cls_labels
    return matches.sum().asscalar()


def bbox_eval(bbox_preds, bbox_labels, bbox_masks):
    """Sum the masked absolute differences between labels and predictions.

    :param bbox_preds: box predictions.
    :param bbox_labels: box labels.
    :param bbox_masks: multiplicative mask applied to the difference.
    :return: the summed absolute error as a Python scalar.
    """
    difference = bbox_labels - bbox_preds
    masked = difference * bbox_masks
    return masked.abs().sum().asscalar()


# Train for 20 epochs, printing per-epoch class error, box MAE and wall time.
for epoch in range(20):
    # Running totals for this epoch: correct class predictions, summed
    # absolute box error, and the element counts used to normalize them.
    acc_sum, mae_sum, n, m = 0.0, 0.0, 0, 0
    train_iter.reset()  # rewind the iterator so the epoch starts from the first batch
    start = time.time()
    for batch in train_iter:
        # Move the images and labels of this batch to the training device.
        X = batch.data[0].as_in_context(ctx)
        Y = batch.label[0].as_in_context(ctx)
        with autograd.record():
            # Generate multi-scale anchors
            anchors, cls_preds, bbox_preds = net(X)
            # Mark categories and offsets for each anchor box
            bbox_labels, bbox_masks, cls_labels = contrib.nd.MultiBoxTarget(
                anchors, Y, cls_preds.transpose((0, 2, 1)))
            # Combined classification + masked box loss.
            l = calc_loss(cls_preds, cls_labels, bbox_preds, bbox_labels,
                          bbox_masks)
        l.backward()
        trainer.step(batch_size)
        # Accumulate the evaluation statistics for the epoch summary.
        acc_sum += cls_eval(cls_preds, cls_labels)
        n += cls_labels.size
        mae_sum += bbox_eval(bbox_preds, bbox_labels, bbox_masks)
        m += bbox_labels.size

    # Class error is one minus accuracy; box MAE is averaged over all label
    # elements seen in the epoch.
    print('epoch %2d, class err %.2e, bbox mae %.2e, time %.1f sec' % (
        epoch + 1, 1 - acc_sum / n, mae_sum / m, time.time() - start))

F:  <module 'mxnet.symbol' from '/anaconda3/envs/gluon/lib/python3.7/site-packages/mxnet/symbol/__init__.py'>
x:  <Symbol data>
Anchor:  <class 'mxnet.symbol.symbol.Symbol'>
cls_pred:  <class 'mxnet.symbol.symbol.Symbol'>
bbox_pred:  <class 'mxnet.symbol.symbol.Symbol'>
epoch  1, class err 5.63e-02, bbox mae 5.49e-03, time 250.9 sec
epoch  2, class err 6.97e-03, bbox mae 5.25e-03, time 245.2 sec
epoch  3, class err 5.15e-03, bbox mae 5.06e-03, time 251.7 sec
epoch  4, class err 4.45e-03, bbox mae 4.74e-03, time 249.4 sec
epoch  5, class err 4.00e-03, bbox mae 4.47e-03, time 238.8 sec
epoch  6, class err 3.90e-03, bbox mae 4.42e-03, time 256.3 sec
epoch  7, class err 3.80e-03, bbox mae 4.34e-03, time 253.6 sec
epoch  8, class err 3.61e-03, bbox mae 4.14e-03, time 254.2 sec
epoch  9, class err 3.66e-03, bbox mae 4.14e-03, time 264.5 sec
epoch 10, class err 3.61e-03, bbox mae 4.06e-03, time 250.9 sec
epoch 11, class err 3.63e-03, bbox mae 4.07e-03, time 251.5 sec
epoch 12, class err 3.55e

# Export the module

In [10]:
# net.export('ssd_test')