In [1]:
from mxnet import gluon,init,nd
from mxnet.gluon import nn

# 1. Residual block
---

In [23]:
class Residual(nn.Block):
    """Residual block: two 3x3 Conv2D+BatchNorm layers with an additive shortcut.

    Main path:  conv1 (3x3, stride ``strides``) -> bn1 -> relu -> conv2 (3x3) -> bn2
    Shortcut:   identity, or a 1x1 conv (stride ``strides``) when ``use_1x1conv``
    Output:     relu(main_path + shortcut)

    Parameters
    ----------
    num_channels : int
        Number of output channels of every convolution in the block.
    use_1x1conv : bool, default False
        If True, the shortcut passes the input through a 1x1 convolution so its
        channel count (and spatial size when ``strides`` > 1) matches the main
        path. If False, the input is added unchanged, so it must already have
        ``num_channels`` channels and ``strides`` must be 1 for ``Y + X`` to
        have matching shapes.
    strides : int, default 1
        Stride of the first 3x3 convolution and of the optional 1x1 shortcut.
    **kwargs
        Forwarded unchanged to ``nn.Block.__init__``.
    """

    def __init__(self, num_channels, use_1x1conv=False, strides=1, **kwargs):
        super(Residual, self).__init__(**kwargs)

        # padding=1 with a 3x3 kernel keeps the spatial size when stride is 1.
        self.conv1 = nn.Conv2D(num_channels, kernel_size=3, padding=1, strides=strides)
        self.conv2 = nn.Conv2D(num_channels, kernel_size=3, padding=1)

        # Optional projection shortcut; None means an identity shortcut.
        if use_1x1conv:
            self.conv3 = nn.Conv2D(num_channels, kernel_size=1, strides=strides)
        else:
            self.conv3 = None

        self.bn1 = nn.BatchNorm()
        self.bn2 = nn.BatchNorm()

    def forward(self, X):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(X))))) + shortcut(X))``."""
        Y = nd.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))
        # Explicit None check (PEP 8) instead of relying on the truthiness of a Block.
        if self.conv3 is not None:
            X = self.conv3(X)
        return nd.relu(Y + X)

In [28]:
# Identity shortcut: 3 channels in, 3 channels out, stride 1,
# so the output shape equals the input shape.
blk = Residual(3)
blk.initialize()
X = nd.random.uniform(shape=(8, 3, 12, 12))
blk(X).shape

(8, 3, 12, 12)

In [29]:
# Projection shortcut: the 1x1 conv lifts channels 3 -> 6 and stride 2 halves
# the 12x12 spatial size to 6x6. Reuses `X` created in the previous cell.
blk = Residual(6, use_1x1conv=True, strides=2)
blk.initialize()
blk(X).shape

(8, 6, 6, 6)

In [31]:
def resnet_block(num_channels, num_residual, first_block=False):
    """Build one ResNet stage: an ``nn.Sequential`` of ``num_residual`` Residual blocks.

    Unless ``first_block`` is True, the stage opens with a downsampling block
    (1x1 projection shortcut, stride 2). Every other block in the stage uses
    an identity shortcut with stride 1.
    """
    stage = nn.Sequential()
    for idx in range(num_residual):
        # Only the opening block of a non-first stage changes the tensor shape.
        downsample = idx == 0 and not first_block
        stage.add(Residual(num_channels,
                           use_1x1conv=downsample,
                           strides=2 if downsample else 1))
    return stage

In [30]:
# Stem: 7x7 conv (stride 2) -> BatchNorm -> ReLU -> 3x3 max-pool (stride 2).
net = nn.Sequential()
net.add(
    nn.Conv2D(64, kernel_size=7, strides=2, padding=3),
    nn.BatchNorm(),
    nn.Activation('relu'),
    nn.MaxPool2D(pool_size=3, strides=2, padding=1),
)

In [32]:
# Four residual stages of 2 blocks each. Stages after the first open with a
# stride-2 projection block that changes the channel count.
# NOTE: appends to `net` in place; re-running this cell without first
# re-running the cell that creates `net` duplicates these stages.
net.add(
    resnet_block(64, 2, first_block=True),
    resnet_block(128, 2),
    resnet_block(256, 2),
    resnet_block(512, 2),
)

In [33]:
# Head: global average pooling, then a Dense layer with 10 outputs
# (presumably 10 classes — confirm against the dataset used for training).
# NOTE: appends to `net` in place, so this cell is not idempotent on re-run.
net.add(
    nn.GlobalAvgPool2D(),
    nn.Dense(10),
)