# In [1]:
from mxnet import gluon, nd
from mxnet.gluon import nn
import mxnet as mx

class Residual(nn.Block):
    '''
    Residual unit: two 3x3 conv + batch-norm layers plus a shortcut branch.

    When ``change_shape`` is true the first convolution uses stride 2 and the
    shortcut is passed through a 1x1 convolution with the same stride and
    ``num_channels`` outputs, so both branches can be added element-wise.
    When it is false every stride is 1 and the input is added back unchanged.
    '''
    def __init__(self, num_channels, change_shape=False, **kwargs):
        super(Residual, self).__init__(**kwargs)
        self.change_shape = change_shape
        # One stride value drives both the first conv and the shortcut
        # projection, so the two branches are strided identically.
        stride = 2 if change_shape else 1

        # Main branch, stage 1: 3x3 conv (possibly strided) + batch norm.
        self.conv1 = nn.Conv2D(num_channels, kernel_size=3, strides=stride, padding=1)
        self.bn1 = nn.BatchNorm()

        # Main branch, stage 2: 3x3 conv with default stride + batch norm.
        self.conv2 = nn.Conv2D(num_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm()

        # Shortcut projection, created only when the main branch is strided.
        if change_shape:
            self.conv3 = nn.Conv2D(num_channels, kernel_size=1, strides=stride)

    def forward(self, X):
        # conv -> bn -> relu
        main = self.conv1(X)
        main = self.bn1(main)
        main = nd.relu(main)
        # conv -> bn; the activation is applied after the residual sum
        main = self.conv2(main)
        main = self.bn2(main)
        # Use the projected input when the main branch was strided.
        shortcut = self.conv3(X) if self.change_shape else X
        return nd.relu(main + shortcut)

    
def resnet_block(num_channels, num_residuals, first_block=False):
    '''
    Build one stage of ``num_residuals`` stacked ``Residual`` units.

    The first unit of the stage is created with ``change_shape=True`` unless
    ``first_block`` is set; every other unit is created with
    ``change_shape=False``. Returns the units wrapped in an ``nn.Sequential``.
    '''
    stage = nn.Sequential()
    for index in range(num_residuals):
        # Only the leading unit of a non-first stage changes the shape.
        leads_stage = index == 0
        stage.add(Residual(num_channels,
                           change_shape=leads_stage and not first_block))
    return stage


def Resnet_18(num_classes=10):
    '''
    Build an 18-layer ResNet as an ``nn.Sequential``.

    Layout: a 7x7 stride-2 convolution with 64 channels, batch norm, ReLU and
    a 3x3 stride-2 max-pool; four residual stages of 2 units each with
    64, 128, 256 and 512 channels; global average pooling; and a final dense
    layer.

    Parameters
    ----------
    num_classes : int, optional
        Number of units in the final dense layer. Defaults to 10, which is
        the value that was previously hard-coded.

    Returns
    -------
    nn.Sequential
        The assembled (uninitialized) network.
    '''
    # (channels, number of residual units) for each stage, in order.
    stage_layout = ((64, 2), (128, 2), (256, 2), (512, 2))

    net = nn.Sequential()
    net.add(nn.Conv2D(64, kernel_size=7, strides=2, padding=3),
            nn.BatchNorm(), nn.Activation('relu'),
            nn.MaxPool2D(pool_size=3, strides=2, padding=1))

    for position, (channels, units) in enumerate(stage_layout):
        # Only the first stage is built with first_block=True.
        net.add(resnet_block(channels, units, first_block=(position == 0)))

    net.add(nn.GlobalAvgPool2D(), nn.Dense(num_classes))

    return net


def Resnet_34(num_classes=10):
    '''
    Build a 34-layer ResNet as an ``nn.Sequential``.

    Layout: a 7x7 stride-2 convolution with 64 channels, batch norm, ReLU and
    a 3x3 stride-2 max-pool; four residual stages of 3, 4, 6 and 3 units with
    64, 128, 256 and 512 channels respectively; global average pooling; and a
    final dense layer.

    Parameters
    ----------
    num_classes : int, optional
        Number of units in the final dense layer. Defaults to 10, which is
        the value that was previously hard-coded.

    Returns
    -------
    nn.Sequential
        The assembled (uninitialized) network.
    '''
    # (channels, number of residual units) for each stage, in order.
    stage_layout = ((64, 3), (128, 4), (256, 6), (512, 3))

    net = nn.Sequential()
    net.add(nn.Conv2D(64, kernel_size=7, strides=2, padding=3),
            nn.BatchNorm(), nn.Activation('relu'),
            nn.MaxPool2D(pool_size=3, strides=2, padding=1))

    for position, (channels, units) in enumerate(stage_layout):
        # Only the first stage is built with first_block=True.
        net.add(resnet_block(channels, units, first_block=(position == 0)))

    net.add(nn.GlobalAvgPool2D(), nn.Dense(num_classes))

    return net