In [2]:
# This is just a single kind of block.
import numpy as np
import torch, torchvision
from torch import nn 
import mmcv
from mmcv import Config
from mmcv.cnn import ConvModule

In [None]:
# 3x3 conv      : kernel_size=3, stride=1, padding=1
# 1x1 conv      : kernel_size=1, stride=1, padding=0
# 1x7 conv      : kernel_size=(1,7), stride=1, padding=(0,3)
# 7x1 conv      : kernel_size=(7,1), stride=1, padding=(3,0)
# dilation d=5  : kernel_size=3, stride=1, padding=5, dilation=5

In [3]:
class RFBBlock(nn.Module):
    """Receptive Field Block (RFB) with a residual connection.

    Follows "Receptive Field Block Net for Accurate and Fast Object
    Detection". The input is fed to four parallel branches (``b0`` .. ``b3``),
    each of which first reduces the channel count to
    ``in_channels // reduction`` with a 1x1 conv and then applies
    progressively larger / more dilated kernels. Every conv uses stride 1 and
    a padding that preserves the spatial size. The branch outputs are
    concatenated along the channel dimension, fused back to ``in_channels``
    by a 1x1 conv, added to the input and passed through a ReLU.

    Args:
        in_channels (int): Number of channels of the input feature map.
            Must be divisible by ``reduction``.
        reduction (int): Channel reduction factor used inside each branch.
        conv_cfg (dict | None): Conv layer config forwarded to every
            ``ConvModule``.
        norm_cfg (dict | None): Norm layer config forwarded to every
            ``ConvModule``. ``None`` (default) means no normalization.
        act_cfg (dict | None): Activation config forwarded to every
            ``ConvModule``.

    Returns:
        output (tensor): ``forward`` returns a tensor with the same shape as
        its input.
    """
    def __init__(self,in_channels,
                 reduction = 4,
                conv_cfg={'type':'Conv2d'},
                norm_cfg=None,
                act_cfg={'type': 'ReLU'},):
        super().__init__()
        self.in_channels = in_channels
        assert self.in_channels % reduction == 0, \
            'in_channels must be divisible by reduction'
        mid_channels = self.in_channels // reduction

        def conv_module(kernel_size, padding=0, dilation=1,
                        in_ch=mid_channels):
            # All convs in the block share stride 1 and the same
            # conv/norm/act configuration; only the kernel geometry differs.
            return ConvModule(in_ch,
                              mid_channels,
                              kernel_size,
                              stride=1,
                              padding=padding,
                              dilation=dilation,
                              conv_cfg=conv_cfg,
                              norm_cfg=norm_cfg,
                              act_cfg=act_cfg)

        def reduce_conv():
            # 1x1 conv that maps the block input to the branch width.
            return conv_module(1, in_ch=self.in_channels)

        # Branch 0: 1x1 -> plain 3x3 conv (no norm / activation on the last
        # layer, kept as a bare nn.Conv2d).
        self.b0 = nn.Sequential(reduce_conv(),
                                nn.Conv2d(mid_channels,
                                          mid_channels,
                                          3,
                                          1, 1))
        # Branch 1: 1x1 -> 3x3 -> 3x3 with dilation 3.
        self.b1 = nn.Sequential(reduce_conv(),
                                conv_module(3, padding=1),
                                conv_module(3, padding=3, dilation=3))
        # Branch 2: 1x1 -> 3x3 -> 3x3 -> 3x3 with dilation 5.
        self.b2 = nn.Sequential(reduce_conv(),
                                conv_module(3, padding=1),
                                conv_module(3, padding=1),
                                conv_module(3, padding=5, dilation=5))
        # Branch 3: 1x1 -> 1x7 -> 7x1 -> 3x3 with dilation 7.
        self.b3 = nn.Sequential(reduce_conv(),
                                conv_module((1, 7), padding=(0, 3)),
                                conv_module((7, 1), padding=(3, 0)),
                                conv_module(3, padding=7, dilation=7))
        # The concatenation of the four branches has 4 * mid_channels
        # channels, which equals in_channels only when reduction == 4; size
        # the fusion conv from the real concatenated width so any valid
        # reduction works.
        self.conv = nn.Conv2d(4 * mid_channels, self.in_channels, 1, 1, 0)
        self.relu = nn.ReLU(True)

    def forward(self,x):
        """Run the four branches, fuse them and add the residual input."""
        branch_outs = [branch(x)
                       for branch in (self.b0, self.b1, self.b2, self.b3)]
        out = self.conv(torch.cat(branch_outs, dim=1))
        return self.relu(out + x)

In [4]:
# Smoke-test fixtures: a random 4-D input with 128 channels and a block
# built for the same channel count.
a = torch.randn((1, 128, 10, 10))
test = RFBBlock(in_channels=128)

In [6]:
# Put the block in eval mode, then export it to 'rfb.onnx' using the sample
# input `a`.
test.eval()
torch.onnx.export(model=test, args=a, f='rfb.onnx')