In [1]:
import sys
sys.path.insert(0,'..')

In [2]:
import mxnet as mx
from mxnet import nd,gluon,contrib,image,autograd,init
from mxnet.gluon import data as gdata,loss as gloss,nn
import gluonbook as gb
import time

  from ._conv import register_converters as _register_converters


* 定义类别预测层：使用一个保持特征图高和宽不变的卷积层，其输出通道数为 每个像素的锚框个数 ×（类别个数 + 1），其中多出的 1 对应背景类

In [3]:
def cls_pred(num_anchors,num_classes):
    """Build the class-prediction layer applied to one feature map.

    Args:
        num_anchors: number of anchor boxes generated per pixel.
        num_classes: number of classes; one extra slot per anchor is
            added on top of these.

    Returns:
        A 3x3 ``nn.Conv2D`` (padding 1) whose output channel count is
        ``num_anchors * (num_classes + 1)``.
    """
    out_channels = num_anchors * (num_classes + 1)
    return nn.Conv2D(channels=out_channels, kernel_size=3, padding=1)


In [4]:
def bbox_pred(num_anchors):
    """Build the bounding-box regression layer applied to one feature map.

    Args:
        num_anchors: number of anchor boxes generated per pixel.

    Returns:
        A 3x3 ``nn.Conv2D`` (padding 1) with ``num_anchors * 4`` output
        channels, i.e. four values per anchor.
    """
    values_per_anchor = 4
    return nn.Conv2D(channels=num_anchors * values_per_anchor,
                     kernel_size=3, padding=1)

* 定义一个合并输出的函数

In [5]:
def flatten_pred(pred):
    """Move the channel axis to the end and flatten one prediction tensor.

    Args:
        pred: a 4-axis prediction; axis 1 is moved to the last position
            (axes reordered as 0, 2, 3, 1) before ``flatten()`` is called.

    Returns:
        The result of ``flatten()`` on the transposed tensor.
    """
    channels_last = pred.transpose((0, 2, 3, 1))
    return channels_last.flatten()

def concat_preds(preds):
    """Flatten every per-scale prediction and join them along axis 1.

    Args:
        preds: iterable of prediction tensors, one per feature-map scale.

    Returns:
        A single array: each element passed through ``flatten_pred`` and
        then concatenated with ``nd.concat(..., dim=1)``.
    """
    flattened = list(map(flatten_pred, preds))
    return nd.concat(*flattened, dim=1)

* 定义降采样模块

In [6]:
def down_sample_blk(num_channels):
    """Build a block that halves the height and width of its input.

    The block stacks two (BatchNorm -> ReLU -> 3x3 Conv with padding 1)
    groups and finishes with a 2x2 max-pooling layer of stride 2.

    Args:
        num_channels: output channels of each convolution.

    Returns:
        The assembled ``nn.Sequential``.
    """
    blk = nn.Sequential()
    layers = []
    for _ in range(2):
        layers += [nn.BatchNorm(),
                   nn.Activation('relu'),
                   nn.Conv2D(num_channels, kernel_size=3, padding=1)]
    # The trailing max-pooling layer is what performs the down-sampling.
    layers.append(nn.MaxPool2D(pool_size=2, strides=2))
    blk.add(*layers)
    return blk

In [7]:
def forward(x, block):
    """Initialize ``block`` and run ``x`` through it once.

    Convenience helper for quick shape checks: it calls
    ``block.initialize()`` on every invocation before the forward pass.

    Args:
        x: input passed to the block.
        block: object exposing ``initialize()`` and ``__call__``.

    Returns:
        Whatever ``block(x)`` returns.
    """
    block.initialize()
    output = block(x)
    return output
# Shape check: push a zero batch of shape (2, 3, 20, 20) through a 10-channel down-sampling block.
forward(x=nd.zeros((2, 3, 20, 20)), block=down_sample_blk(10)).shape

(2, 10, 10, 10)

* 定义基础网络块，用来抽取特征

In [8]:
# Any other feature extractor (e.g. a ResNet) could be substituted here.
def base_net():
    """Build the base feature-extraction network.

    Returns:
        An ``nn.Sequential`` made of three down-sampling blocks with
        16, 32 and 64 channels, in that order.
    """
    blk = nn.Sequential()
    blk.add(*[down_sample_blk(width) for width in (16, 32, 64)])
    return blk
# Shape check: push a zero batch of shape (2, 3, 256, 256) through the base network.
forward(x=nd.zeros((2, 3, 256, 256)), block=base_net()).shape

(2, 64, 32, 32)

## 完整的SSD模型
* 基础网络块
* 3个高和宽减半模块
* 全局最大池化层

以上每个模块都会在自己的输出特征图上生成锚框，并预测锚框的类别和边界框偏移量。

In [9]:
def get_blk(i):
    """Return the feature block used at stage ``i`` of the model.

    Args:
        i: stage index.

    Returns:
        ``base_net()`` for stage 0, an ``nn.GlobalMaxPool2D`` for stage 4,
        and a 128-channel down-sampling block for every other index.
    """
    if i == 0:
        return base_net()
    if i == 4:
        return nn.GlobalMaxPool2D()
    return down_sample_blk(128)

## 定义SSD的运算

In [10]:
def blk_forward(X,blk,sizes,ratios,cls_predictor,bbox_predictor):
    """Run one SSD stage: extract features, generate anchors, predict.

    Args:
        X: input to this stage.
        blk: feature block applied to ``X``.
        sizes: anchor sizes handed to ``MultiBoxPrior``.
        ratios: anchor aspect ratios handed to ``MultiBoxPrior``.
        cls_predictor: layer producing class predictions from the features.
        bbox_predictor: layer producing box-offset predictions from the features.

    Returns:
        Tuple ``(features, anchors, cls_preds, bbox_preds)``. Per the original
        author's notes the features are (batch, channels, height, width),
        the class predictions have anchors * (classes + 1) channels and the
        box predictions have anchors * 4 channels.
    """
    # Output of this stage; it is also the input of the next stage.
    feature_map = blk(X)
    # Anchors are generated from the feature map of this stage.
    stage_anchors = contrib.ndarray.MultiBoxPrior(
        feature_map, sizes=sizes, ratios=ratios)
    class_scores = cls_predictor(feature_map)
    box_offsets = bbox_predictor(feature_map)
    return (feature_map, stage_anchors, class_scores, box_offsets)
        
    

In [11]:
class TinySSD(nn.Block):
    """Small single-shot multibox detector built from five stages.

    Each stage owns a feature block (``blk_<i>``), a class predictor
    (``cls_predictor_<i>``) and a box predictor (``bbox_predictor_<i>``),
    all registered as attributes of the model.

    Args:
        num_classes: number of object classes to predict.
        **kwargs: forwarded to ``nn.Block.__init__``.
    """

    def __init__(self,num_classes,**kwargs):
        super(TinySSD,self).__init__(**kwargs)
        self.num_classes = num_classes
        # Anchor sizes and aspect ratios, one entry per stage.
        self.sizes = [[0.2, 0.272], [0.37, 0.447], [0.54, 0.619], [0.71, 0.79],
                        [0.88, 0.961]]
        self.ratios = [[1, 2, 0.5]] * 5
        # Per stage: feature block + class predictor + box predictor.
        for stage, (stage_sizes, stage_ratios) in enumerate(zip(self.sizes, self.ratios)):
            anchors_per_pixel = len(stage_sizes) + len(stage_ratios) - 1
            setattr(self, 'blk_%d' % stage, get_blk(stage))
            setattr(self, 'cls_predictor_%d' % stage,
                    cls_pred(anchors_per_pixel, self.num_classes))
            setattr(self, 'bbox_predictor_%d' % stage,
                    bbox_pred(anchors_per_pixel))

    def forward(self,X):
        """Run all five stages and merge their outputs.

        Args:
            X: input image batch.

        Returns:
            Tuple ``(anchors, cls_preds, bbox_preds)``: the anchors of all
            stages concatenated along axis 1, the class predictions reshaped
            so the last axis has ``num_classes + 1`` entries, and the
            flattened, concatenated box predictions.
        """
        all_anchors, all_cls, all_bbox = [], [], []
        for stage in range(5):
            stage_out = blk_forward(
                X,
                getattr(self, 'blk_%d' % stage),
                self.sizes[stage],
                self.ratios[stage],
                getattr(self, 'cls_predictor_%d' % stage),
                getattr(self, 'bbox_predictor_%d' % stage))
            # The stage's feature map becomes the input of the next stage.
            X, stage_anchors, stage_cls, stage_bbox = stage_out
            all_anchors.append(stage_anchors)
            all_cls.append(stage_cls)
            all_bbox.append(stage_bbox)

        merged_anchors = nd.concat(*all_anchors, dim=1)
        merged_cls = concat_preds(all_cls).reshape((0, -1, self.num_classes + 1))
        merged_bbox = concat_preds(all_bbox)
        return (merged_anchors, merged_cls, merged_bbox)


#


## 下面开始进行模型训练

* 首先是获得训练的数据集

In [12]:
# Load the Pikachu detection data as training / test iterators.
batch_size = 8
train_iter, test_iter = gb.load_data_pikachu(batch_size)

# Device to train on (a GPU, judging by the helper's name -- fallback behaviour not shown here).
ctx = gb.try_gpu()

# Single-class detector, Xavier-initialized on the chosen device.
ssd = TinySSD(num_classes=1)
ssd.initialize(init=init.Xavier(), ctx=ctx, force_reinit=True)

# Plain SGD with weight decay.
trainer = gluon.Trainer(
    ssd.collect_params(),
    'sgd',
    {'learning_rate': 0.05, 'wd': 5e-4},
)

In [14]:
# Ask the iterator to use a (3, 5) label shape -- presumably up to 3 objects per image
# with 5 fields each; TODO confirm against the iterator's documentation.
train_iter.reshape(label_shape=(3,5))

### 下面定义评价准确度的函数，以及损失函数

* 定义类别损失 cls_loss 和边界框损失 bbox_loss，并定义带掩码的边界框损失函数 bbox_L1Loss（负类锚框不计入损失）

In [15]:
# Loss objects: softmax cross-entropy for anchor classes, L1 for box offsets.
cls_loss, bbox_loss = gloss.SoftmaxCrossEntropyLoss(), gloss.L1Loss()

In [16]:
L1Loss = gloss.L1Loss()


def bbox_L1Loss(bbox_preds,bbox_tag,masks):
    """Masked L1 loss between predicted and target box offsets.

    Args:
        bbox_preds: predicted box offsets.
        bbox_tag: target box offsets.
        masks: multiplied into both predictions and targets so that
            negative (background) anchors do not contribute to the loss.

    Returns:
        The value of the module-level ``L1Loss`` on the masked inputs.
    """
    masked_preds = bbox_preds * masks
    masked_targets = bbox_tag * masks
    return L1Loss(masked_preds, masked_targets)

* 定义总损失函数 cal_total_loss：类别损失 cls_loss 与带掩码的边界框损失 bbox_loss 之和

In [17]:
def cal_total_loss(bbox_preds,bbox_labels,bbox_masks,cls_preds,cls_labels):
    """Sum of the classification loss and the masked box-regression loss.

    Args:
        bbox_preds: predicted box offsets.
        bbox_labels: target box offsets.
        bbox_masks: mask multiplied into both box predictions and targets.
        cls_preds: predicted class scores.
        cls_labels: target class labels.

    Returns:
        ``cls_loss(cls_preds, cls_labels)`` plus ``bbox_loss`` evaluated on
        the masked box predictions and targets.
    """
    masked_preds = bbox_preds * bbox_masks
    masked_labels = bbox_labels * bbox_masks
    return cls_loss(cls_preds, cls_labels) + bbox_loss(masked_preds, masked_labels)

* 定义评价精度的函数

In [18]:
def eval_cls_accuracy(cls_preds,cls_labels):
    """Fraction of entries whose arg-max class equals the label.

    Args:
        cls_preds: class scores; the arg-max is taken over the last axis.
        cls_labels: target class labels compared element-wise with the arg-max.

    Returns:
        The mean of the element-wise equality, converted with ``asscalar()``.
    """
    predicted = cls_preds.argmax(axis=-1)
    hits = predicted == cls_labels
    return hits.mean().asscalar()

### 定义训练函数

In [19]:
def train(num_epochs):
    """Train the module-level ``ssd`` model on ``train_iter``.

    Relies on the notebook globals ``train_iter``, ``ssd``, ``trainer``,
    ``ctx``, ``batch_size``, ``cls_loss`` and ``bbox_L1Loss``. Every fifth
    epoch it prints the epoch's mean classification accuracy, mean
    classification loss, mean box L1 error and the elapsed time.

    Args:
        num_epochs: number of passes over ``train_iter``.
    """
    for epoch in range(num_epochs):
        # train_iter is not a DataLoader, so its cursor must be rewound by hand each epoch.
        train_iter.reset()
        start = time.time()
        train_cls_accuracy = 0
        train_MAE = 0
        train_cls_loss = 0
        # Explicit counter instead of the leaked loop index, which is undefined
        # when the iterator yields no batch.
        num_batches = 0
        for batch in train_iter:
            X = batch.data[0].as_in_context(ctx)
            Y = batch.label[0].as_in_context(ctx)
            with autograd.record():
                # Forward pass: anchors plus class / box predictions for every anchor.
                anchors,cls_preds,bbox_preds = ssd(X)
                # Build the training targets from the ground-truth labels:
                # box offsets, box masks and per-anchor class labels.
                bbox_offsets,bbox_masks,cls_labels = contrib.nd.MultiBoxTarget(
                    anchors,Y,cls_preds.transpose((0,2,1)))
                l_cls = cls_loss(cls_preds,cls_labels)
                l_bbox = bbox_L1Loss(bbox_preds,bbox_offsets,bbox_masks)
                l = l_cls+l_bbox

            # Backward pass and parameter update.
            l.backward()
            trainer.step(batch_size)
            # Accumulate per-batch metrics for the epoch summary.
            train_cls_loss += l_cls.mean().asscalar()
            train_MAE += l_bbox.mean().asscalar()
            train_cls_accuracy += eval_cls_accuracy(cls_preds,cls_labels)
            num_batches += 1
        if (epoch+1)%5==0:
            # Guard against an empty iterator so the averages never divide by zero.
            denom = max(num_batches, 1)
            # The arguments must form one tuple: '%' binds tighter than '+', so
            # "fmt % epoch+1, ..." would format with a single value and raise TypeError.
            print('epoch %2d , class acc %.2f,class err %.2e,bbox mae %.2e,time %.1f sec'
                  % (epoch+1,train_cls_accuracy/denom,train_cls_loss/denom,
                     train_MAE/denom,time.time()-start))
            

In [None]:
# Run a single training epoch.
train(num_epochs=1)