In [1]:
import mxnet as mx
from mxnet import gluon,autograd,contrib,image,nd
from mxnet.gluon import data as gdata,loss as gloss,nn
import gluoncv
import matplotlib as mpl
import matplotlib.pyplot as plt
import sys
import time

  from ._conv import register_converters as _register_converters


## 首先定义残差小块

In [10]:
# 3x3 convolution helper for the Bottleneck block that ResNet-50 is built from.
def _conv3x3(channels,strides=1,in_channels=0):
    """Build a bias-free 3x3 ``nn.Conv2D`` with padding 1.

    Args:
        channels: Number of output channels.
        strides: Convolution stride (default 1).
        in_channels: Number of input channels, forwarded to ``nn.Conv2D``.
            Per the Gluon documentation, 0 (the default) defers the choice
            until the first forward pass.

    Returns:
        The newly constructed ``nn.Conv2D`` layer.
    """
    # use_bias=False brings this helper in line with every other convolution
    # in the notebook. The only caller visible here (Bottleneck.conv2) feeds
    # directly into a BatchNorm, so a convolution bias would be redundant.
    # NOTE(review): parameter files saved from the earlier biased definition
    # contain an extra bias per 3x3 conv and will not load unchanged.
    return nn.Conv2D(channels=channels,kernel_size=3,strides=strides,padding=1,
                     use_bias=False,in_channels=in_channels)

In [20]:
class Bottleneck(nn.Block):
    """Bottleneck residual block with BatchNorm -> ReLU placed before each conv.

    Main path: 1x1 conv to ``channels//4`` -> 3x3 conv at ``channels//4``
    (this one carries ``strides``) -> 1x1 conv to ``channels``. The shortcut
    is the untouched input, or, when ``down_sample`` is true, a 1x1 conv with
    stride ``strides`` applied to the input after the first BatchNorm + ReLU.

    Args:
        channels: Number of output channels of the block.
        down_sample: Whether to create the 1x1 projection on the shortcut.
        strides: Stride of the 3x3 conv and of the shortcut projection.
        in_channels: Input channel count; only handed to the shortcut
            projection (the first 1x1 conv is built without it).
        **kwargs: Forwarded to ``nn.Block.__init__``.
    """

    def __init__(self,channels,down_sample=True,strides=1,in_channels=0,**kwargs):
        super(Bottleneck,self).__init__(**kwargs)
        mid_channels = channels//4

        # Keep this creation order: it is the order forward() applies the
        # layers in, and the order in which they are registered on the block.
        self.bn1 = nn.BatchNorm()
        self.conv1 = nn.Conv2D(channels=mid_channels,kernel_size=1,strides=1,use_bias=False)
        self.bn2 = nn.BatchNorm()
        self.conv2 = _conv3x3(mid_channels,strides,mid_channels)
        self.bn3 = nn.BatchNorm()
        self.conv3 = nn.Conv2D(channels=channels,kernel_size=1,strides=1,use_bias=False)

        # Optional projection for the shortcut; None means identity shortcut.
        self.downsample = (
            nn.Conv2D(channels=channels,kernel_size=1,strides=strides,
                      use_bias=False,in_channels=in_channels)
            if down_sample else None
        )

    def forward(self,x):
        """Return ``main_path(x) + shortcut(x)``."""
        shortcut = x

        # Stage 1: the projection (if any) branches off after BN + ReLU.
        out = nd.Activation(self.bn1(x),act_type='relu')
        if self.downsample:
            shortcut = self.downsample(out)
        out = self.conv1(out)

        # Stage 2: the 3x3 conv, which applies the block's stride.
        out = nd.Activation(self.bn2(out),act_type='relu')
        out = self.conv2(out)

        # Stage 3: expand back to the block's output width.
        out = nd.Activation(self.bn3(out),act_type = 'relu')
        out = self.conv3(out)

        return out+shortcut

In [34]:
def netforward(net,shape=(1,3,512,512)):
    """Smoke-test ``net``: run one random batch through it and print the shapes.

    Args:
        net: Object exposing ``initialize(force_reinit=...)`` and ``__call__``
            (e.g. a Gluon block). NOTE: it is initialised with
            ``force_reinit=True`` before the forward pass.
        shape: Shape of the random input drawn with ``nd.random.uniform``.
            Defaults to ``(1,3,512,512)``, the value that used to be
            hard-coded, so existing one-argument calls behave as before.

    Returns:
        None. The input shape and the output shape are printed.
    """
    # The input is drawn before initialisation, exactly as in the original,
    # so the order of calls made on the random generator does not change.
    x = nd.random.uniform(shape=shape)
    net.initialize(force_reinit = True)

    print("x shape ",x.shape)
    print('output shape',net(x).shape)

In [22]:
# Sanity-check a single bottleneck block with 64 output channels.
res_blk = Bottleneck(channels=64)
netforward(net=res_blk)

x shape  (1, 3, 256, 256)
output shape (1, 64, 256, 256)


##  定义resnet50

In [24]:
# Stage configuration passed to ResNetV2 below.
# num_blks: number of residual blocks in each stage.
num_blks = [3, 4, 6, 3]
# num_channels: the first entry is the output width of the 7x7 stem conv;
# the remaining entries are the output widths of the stages, in order.
num_channels = [
    64,    # 7x7 stem convolution
    256,   # stage 1
    512,   # stage 2
    1024,  # stage 3
    2048,  # stage 4
]

In [49]:
class ResNetV2(nn.Block):
    """Residual feature extractor assembled from ``block`` units.

    ``self.features`` is an ``nn.Sequential`` holding, in order:

    1. ``BatchNorm(scale=False, center=False)`` applied to the raw input,
    2. the stem: 7x7 conv (stride 2, padding 3, no bias) -> BatchNorm -> ReLU
       -> 3x3 max-pool (stride 2, padding 1),
    3. one ``nn.Sequential`` per stage, built by ``_make_layer``.

    There is no pooling or classifier head: ``forward`` returns the feature
    map produced by the last stage.

    Args:
        block: Block class, called as
            ``block(channels, down_sample, stride, in_channels=...)``.
        layers: Number of blocks in each stage.
        channels: ``channels[0]`` is the stem's output width and
            ``channels[i+1]`` is the output width of stage ``i``; it must hold
            exactly one entry more than ``layers``.
        **kwargs: Forwarded to ``nn.Block.__init__``.
    """

    def __init__(self,block,layers,channels,**kwargs):
        super(ResNetV2,self).__init__(**kwargs)
        assert len(layers) == len(channels) - 1, \
            "channels needs one entry more than layers (stem width + one per stage)"

        self.features = nn.Sequential()
        self.features.add(nn.BatchNorm(scale=False,center=False))

        self.features.add(nn.Conv2D(channels[0],7,2,3,use_bias=False))
        self.features.add(nn.BatchNorm())
        self.features.add(nn.Activation('relu'))
        self.features.add(nn.MaxPool2D(3,2,1))

        in_channels = channels[0]
        for i,num_blk in enumerate(layers):
            # The first AND the last stage keep stride 1; only the stages in
            # between halve the resolution.
            stride = 1 if i==0 or i==len(layers)-1 else 2
            self.features.add(self._make_layer(block,num_blk,channels[i+1],
                                              stride,in_channels=in_channels
                                              ))
            in_channels = channels[i+1]

    def _make_layer(self,block,layers,channels,stride,in_channels=0):
        """Build one stage: ``layers`` blocks that output ``channels`` channels.

        Args:
            block: Block class to instantiate.
            layers: Number of blocks in this stage.
            channels: Output width of every block in the stage.
            stride: Stride given to the first block; the rest use stride 1.
            in_channels: Width of the tensor entering the stage.

        Returns:
            ``nn.Sequential`` containing the stage's blocks.
        """
        layer = nn.Sequential()
        # Only the first block of a stage may change width or resolution, so
        # only it can need a projection shortcut. Checking the stride as well
        # as the width fixes a latent bug: a strided stage whose width equals
        # its input width used to get an identity shortcut, so the main path
        # and the shortcut had different spatial sizes when they were added.
        needs_projection = channels != in_channels or stride != 1
        layer.add(block(channels,needs_projection,stride,in_channels=in_channels))

        for _ in range(layers-1):
            layer.add(block(channels,False,1,in_channels=channels))

        return layer

    def forward(self,x):
        """Run ``x`` through ``self.features`` and return the result."""
        return self.features(x)
            
        
        

In [50]:
# Assemble the backbone from Bottleneck blocks using the stage configuration
# (num_blks / num_channels) defined earlier in the notebook.
resnet50 = ResNetV2(block=Bottleneck,layers=num_blks,channels=num_channels)

In [51]:
# Smoke-test the whole backbone; netforward initialises the network with
# force_reinit=True before the forward pass.
netforward(net=resnet50)

x shape  (1, 3, 512, 512)
output shape (1, 2048, 32, 32)
