In [None]:
from gluoncv.data import VOCDetection

# Build the VOC2012 training and validation datasets from the same devkit root.
# Each split name is looked up as the matching .txt file under VOC2012/ImageSets/Main.
train_dataset, val_dataset = (
    VOCDetection(root='../../../data/VOCdevkit', splits=[(2012, split_name)])
    for split_name in ('train', 'val')
)

# Report how many images each split contains.
print('Training images:', len(train_dataset))
print('Validation images:', len(val_dataset))

In [None]:
from gluoncv import model_zoo
# Fetch the SSD-300 / VGG16-atrous / VOC network definition from the model zoo.
# pretrained_base=True would load the pretrained model weights; False leaves them unloaded.
net = model_zoo.get_model(
    'ssd_300_vgg16_atrous_voc',
    pretrained_base=False,
)

In [None]:
# Image augmentation is very important for SSD networks, so the augmenting transforms are set up below.
from gluoncv.data.transforms import presets
from gluoncv import utils
from mxnet import nd
from gluoncv.data.batchify import Tuple, Stack, Pad
from mxnet.gluon.data import DataLoader
from mxnet import autograd
import mxnet as mx

# Input resolution used for the dummy forward pass and for both transforms
# (512x512 is assumed here).
width, height = 512, 512

# Single-device setup: use the first GPU when one is present, otherwise the CPU.
ctx = mx.gpu() if mx.context.num_gpus() > 0 else mx.cpu()
net.initialize(init=mx.init.Xavier(), ctx=ctx)

# One dummy forward pass in training mode yields the anchors that the training
# transform needs. The input must live on the same device as the parameters.
x = nd.zeros(shape=(1, 3, height, width), ctx=ctx)
with autograd.train_mode():
    cls_preds, box_preds, anchors = net(x)
# The transform runs inside the DataLoader workers, so keep the anchors on the CPU.
anchors = anchors.as_in_context(mx.cpu())

train_transform = presets.ssd.SSDDefaultTrainTransform(width, height, anchors)
# The validation transform takes no anchors; its third positional parameter is `mean`.
val_transform = presets.ssd.SSDDefaultValTransform(width, height)

batch_size = 64  # trial run: keep it small to inspect results; raise it (e.g. 128) for real training
# you can make it larger(if your CPU has more cores) to accelerate data loading
num_workers = 2

# With anchors supplied, train_transform emits three arrays per sample
# (image, class targets, box targets), so three Stack() batchifiers are needed.
batchify_fn = Tuple(Stack(), Stack(), Stack())
train_loader = DataLoader(
    train_dataset.transform(train_transform),
    batch_size,
    shuffle=True,
    batchify_fn=batchify_fn,
    last_batch='rollover',
    num_workers=num_workers)  # number of worker processes used for loading

# Validation samples are (image, label) pairs with a variable number of boxes,
# so they need Stack() for images and Pad() for labels rather than batchify_fn.
# val_batchify_fn = Tuple(Stack(), Pad(pad_val=-1))
# val_loader = DataLoader(
#     val_dataset.transform(val_transform),
#     batch_size,
#     shuffle=False,
#     batchify_fn=val_batchify_fn,
#     last_batch='keep',
#     num_workers=num_workers)

In [None]:
from gluoncv.loss import SSDMultiBoxLoss
from mxnet import gluon

import time

# Loss, running-average metrics and optimizer for SSD training.
mbox_loss = SSDMultiBoxLoss()
ce_metric = mx.metric.Loss('CrossEntropy')
smoothl1_metric = mx.metric.Loss('SmoothL1')
trainer = gluon.Trainer(
    net.collect_params(), 'sgd',
    {'learning_rate': 0.001, 'wd': 0.0005, 'momentum': 0.9})

# split_and_load expects a list of contexts; accept either a single context or a list.
ctx_list = list(ctx) if isinstance(ctx, (list, tuple)) else [ctx]

for epoch in range(20):
    start = time.time()
    # Reset the metrics so the printed losses describe this epoch only.
    ce_metric.reset()
    smoothl1_metric.reset()
    # A Gluon DataLoader restarts from the beginning each time it is iterated,
    # so no explicit reset call is needed (and none exists).
    for batch in train_loader:
        data = gluon.utils.split_and_load(batch[0], ctx_list=ctx_list, batch_axis=0)
        cls_targets = gluon.utils.split_and_load(batch[1], ctx_list=ctx_list, batch_axis=0)
        box_targets = gluon.utils.split_and_load(batch[2], ctx_list=ctx_list, batch_axis=0)
        with autograd.record():
            cls_preds = []
            box_preds = []
            for x in data:
                cls_pred, box_pred, _ = net(x)
                cls_preds.append(cls_pred)
                box_preds.append(box_pred)
            sum_loss, cls_loss, box_loss = mbox_loss(
                cls_preds, box_preds, cls_targets, box_targets)
            autograd.backward(sum_loss)
        # since we have already normalized the loss, we don't want to normalize
        # by batch-size anymore
        trainer.step(1)
        ce_metric.update(0, [l * batch_size for l in cls_loss])
        smoothl1_metric.update(0, [l * batch_size for l in box_loss])
    # Read the metrics once per epoch instead of once per batch.
    name1, loss1 = ce_metric.get()
    name2, loss2 = smoothl1_metric.get()
    print('epoch:', epoch, 'class loss:', loss1, 'bbox loss:', loss2,
          'time: %.1f sec' % (time.time() - start))

In [None]:
# Prediction
from mxnet import image

img = image.imread('d2l-zh/img/pikachu.jpg')
# Resize to the test resolution, then apply the same preprocessing as the
# GluonCV SSD training transform: HWC uint8 -> CHW float in [0, 1], followed by
# ImageNet mean/std normalization.
feature = image.imresize(img, 256, 256)
X = nd.image.normalize(
    nd.image.to_tensor(feature),
    mean=(0.485, 0.456, 0.406),
    std=(0.229, 0.224, 0.225)).expand_dims(axis=0)

def predict(X):
    """Detect objects in the first image of a preprocessed batch.

    Args:
        X: NDArray image batch laid out as (batch, channel, height, width).

    Returns:
        NDArray with one row per kept detection, in the layout
        [class_id, score, xmin, ymin, xmax, ymax]. Box corners are divided by
        the input width/height, so they are fractions of the image size.
    """
    X = X.as_in_context(ctx)
    # Outside autograd training mode the GluonCV SSD decodes the boxes and runs
    # NMS itself, returning (ids, scores, bboxes) with boxes in input-pixel
    # coordinates and suppressed entries marked with id -1.
    ids, scores, bboxes = net(X)
    in_h, in_w = X.shape[2:4]
    scale = nd.array((in_w, in_h, in_w, in_h), ctx=bboxes.context)
    output = nd.concat(ids, scores, bboxes / scale, dim=-1)
    # Keep only the rows that survived NMS.
    idx = [i for i, row in enumerate(output[0]) if row[0].asscalar() != -1]
    return output[0, idx]

# The d2l plotting helpers are used below but were never imported in this notebook.
# NOTE(review): module name assumed from the d2l-zh companion package - confirm.
import d2lzh as d2l

output = predict(X)

# Display
d2l.set_figsize((5, 5))

def display(img, output, threshold):
    """Show img and draw every detection whose score reaches the threshold.

    Args:
        img: image NDArray; its first two dimensions are read as (height, width).
        output: iterable of detection rows; index 1 holds the score and
            indices 2-5 hold box corners that get scaled by the image size.
        threshold: rows with a score below this value are skipped.
    """
    shown = d2l.plt.imshow(img.asnumpy())
    # The image size is the same for every row, so read it once.
    img_h, img_w = img.shape[0:2]
    for det in output:
        confidence = det[1].asscalar()
        if confidence < threshold:
            continue
        # Scale the box corners by (w, h, w, h) on the row's own device.
        size = nd.array((img_w, img_h, img_w, img_h), ctx=det.context)
        scaled_box = det[2:6] * size
        d2l.show_bboxes(shown.axes, [scaled_box], '%.2f' % confidence, 'w')

display(img, output, threshold=0.3) # draw only the boxes whose confidence is at least 0.3