# Deep Residual Neural Network Pre Activation

论文地址： [Identity Mappings in Deep Residual Networks](https://arxiv.org/abs/1603.05027)

In [5]:
import mxnet as mx

from mxnet import nd
from mxnet import gluon

In [55]:
def BN_ReLU_Conv(num_filter, kernel_size, strides=1, padding=0, erase_relu=False,
                 erase_conv=False, use_bias=False):
    '''
    Build a pre-activation unit: BatchNorm, then optionally ReLU, then
    optionally Conv2D, wrapped in a HybridSequential.

    num_filter : int
        Output channels of the convolution
    kernel_size : int
        Kernel size of the convolution
    strides : int or tuple
        Convolution strides
    padding : int or tuple
        Convolution padding
    erase_relu : bool
        When True the ReLU activation is left out
    erase_conv : bool
        When True the convolution is left out
    use_bias : bool
        Whether the convolution carries a bias term
    '''
    block = gluon.nn.HybridSequential()
    with block.name_scope():
        # Layers are created in execution order inside the scope so that the
        # automatically generated layer names stay stable.
        layers = [gluon.nn.BatchNorm(axis=1)]
        if not erase_relu:
            layers.append(gluon.nn.Activation('relu'))
        if not erase_conv:
            layers.append(
                gluon.nn.Conv2D(
                    num_filter,
                    kernel_size=kernel_size,
                    strides=strides,
                    padding=padding,
                    use_bias=use_bias,
                )
            )
        for layer in layers:
            block.add(layer)
    return block

def BN_ReLU(erase_relu=False):
    '''
    Build a HybridSequential holding BatchNorm followed, optionally, by ReLU.

    erase_relu : bool
        When True only the BatchNorm layer is added
    '''
    block = gluon.nn.HybridSequential()
    with block.name_scope():
        layers = [gluon.nn.BatchNorm(axis=1)]
        if not erase_relu:
            layers.append(gluon.nn.Activation('relu'))
        for layer in layers:
            block.add(layer)
    return block

class Residual_Unit(gluon.nn.HybridBlock):
    '''
    One residual unit: a stack of BN-ReLU-Conv layers plus a shortcut that is
    added to the stack's output.

    num_filter : int
        Output channels of the unit
    strides : int or tuple
        Strides of the 3x3 convolution that may downsample (the first one in
        the plain layout, the middle one in the bottleneck layout) and of the
        1x1 shortcut convolution
    dim_match : bool
        True  -> the input itself is used as the shortcut
        False -> the shortcut is BN -> ReLU -> 1x1 convolution of the input
    bottle_neck : bool
        True  -> 1x1, 3x3, 1x1 convolutions, the first two with
                 int(num_filter * 0.25) channels
        False -> two 3x3 convolutions
    '''

    def __init__(self, num_filter, strides, dim_match=True, bottle_neck=False, **kwargs):
        super().__init__(**kwargs)
        self.dim_match = dim_match
        self.bottle_neck = bottle_neck
        self.residual = gluon.nn.HybridSequential()
        with self.name_scope():
            if bottle_neck:
                narrow = int(num_filter * 0.25)
                branch = [
                    BN_ReLU_Conv(narrow, kernel_size=1),
                    BN_ReLU_Conv(narrow, kernel_size=3, strides=strides, padding=1),
                    BN_ReLU_Conv(num_filter, kernel_size=1),
                ]
            else:
                branch = [
                    BN_ReLU_Conv(num_filter, kernel_size=3, strides=strides, padding=1),
                    BN_ReLU_Conv(num_filter, kernel_size=3, strides=1, padding=1),
                ]
            # The residual branch ends with one more BatchNorm layer.
            branch.append(gluon.nn.BatchNorm(axis=1))
            for layer in branch:
                self.residual.add(layer)
        if not dim_match:
            # Projection shortcut, only needed when the input and output differ.
            self.conv1 = gluon.nn.Conv2D(num_filter, kernel_size=1, strides=strides, use_bias=False)
            self.bn_relu = BN_ReLU()

    def hybrid_forward(self, F, X):
        # The shortcut is evaluated before the residual branch.
        shortcut = X if self.dim_match else self.conv1(self.bn_relu(X))
        return self.residual(X) + shortcut

class ResNet(gluon.nn.HybridBlock):
    '''
    Pre-activation ResNet assembled from Residual_Unit stages.

    The layers live in ``self.resnet`` as a sequence of blocks: an input stem,
    one block per stage, a final BatchNorm + ReLU, and a global-average-pool +
    Dense classifier.
    '''

    def __init__(self, unit_list, filter_list, num_classes, data_type, bottle_neck=False, debug=False, **kwargs):
        '''
        unit_list : list
            number of residual units in each stage
        filter_list : list
            channel sizes; filter_list[0] is the output channels of the stem
            and filter_list[i+1] the output channels of stage i, so it needs
            at least len(unit_list) + 1 entries
        num_classes : int
            number of classes produced by the final Dense layer
        data_type : str
            'cifar10' or 'imagenet'; selects the stem layout.
            Any other value raises ValueError.
        bottle_neck : bool
            use the bottleneck architecture in every residual unit
        debug : bool
            print the output shape of each block during the forward pass
            (only when the block outputs are NDArrays)
        '''
        super().__init__(**kwargs)
        num_stage = len(unit_list)
        self.debug = debug

        resnet = self.resnet = gluon.nn.HybridSequential()
        with self.name_scope():
            # first layer
            blk1 = gluon.nn.HybridSequential()
            blk1.add(gluon.nn.BatchNorm(axis=1))
            if data_type == 'cifar10':
                blk1.add(gluon.nn.Conv2D(filter_list[0], kernel_size=3, padding=1, use_bias=False))
            elif data_type == 'imagenet':
                blk1.add(
                    gluon.nn.Conv2D(filter_list[0], kernel_size=7, strides=2, padding=3, use_bias=False),
                    gluon.nn.BatchNorm(axis=1),
                    gluon.nn.Activation('relu'),
                    gluon.nn.MaxPool2D(pool_size=3, strides=2, padding=1)
                )
            else:
                raise ValueError('do not support {} yet.'.format(data_type))
            resnet.add(blk1)

            # residual layers: the first unit of every stage uses a projection
            # shortcut (dim_match=False); every stage except the first also
            # uses stride 2 in that unit.
            for i in range(num_stage):
                blk2 = gluon.nn.HybridSequential()
                blk2.add(Residual_Unit(filter_list[i+1], 1 if i==0 else 2, False,
                                      bottle_neck=bottle_neck))
                for j in range(unit_list[i] - 1):
                    blk2.add(Residual_Unit(filter_list[i+1], 1, True, bottle_neck=bottle_neck))
                resnet.add(blk2)

            # final BatchNorm + ReLU after the last residual unit
            blk3 = gluon.nn.HybridSequential()
            blk3.add(
                gluon.nn.BatchNorm(axis=1),
                gluon.nn.Activation('relu')
            )
            resnet.add(blk3)

            # classification layer
            blk4 = gluon.nn.HybridSequential()
            blk4.add(
                gluon.nn.GlobalAvgPool2D(),
                gluon.nn.Dense(num_classes)
            )
            resnet.add(blk4)

    def hybrid_forward(self, F, X):
        out = X
        for i, blk in enumerate(self.resnet):
            out = blk(out)
            # Only NDArrays carry a concrete shape. Once the network is
            # hybridized the outputs are Symbols, and reading ``.shape`` on
            # them would fail, so the debug print is skipped in that case.
            if self.debug and isinstance(out, nd.NDArray):
                print('blk {} : {}'.format(i+1, out.shape))
        return out

In [62]:
depth = 164

# Choose the unit layout from the requested depth: depths of 164 and above use
# bottleneck units ((depth - 2) must be a multiple of 9), smaller depths use
# plain units ((depth - 2) must be a multiple of 6).
if depth >= 164 and (depth - 2) % 9 == 0:
    bottle_neck = True
    filter_list = [16, 64, 128, 256]
    per_unit = [(depth - 2) // 9]
elif depth < 164 and (depth - 2) % 6 == 0:
    bottle_neck = False
    filter_list = [16, 16, 32, 64]
    per_unit = [(depth - 2) // 6]
else:
    raise ValueError("depth {} error.".format(depth))

# All three stages hold the same number of units.
unit_list = per_unit * 3

# Build the CIFAR-10 network and push one random batch through it; with
# debug=True the shape after every block is printed.
symbol = ResNet(
    unit_list=unit_list,
    filter_list=filter_list,
    num_classes=10,
    data_type='cifar10',
    bottle_neck=bottle_neck,
    debug=True,
)
symbol.initialize()
sample = nd.random.uniform(shape=(128, 3, 32, 32))
y = symbol(sample)

blk 1 : (128, 16, 32, 32)
blk 2 : (128, 64, 32, 32)
blk 3 : (128, 128, 16, 16)
blk 4 : (128, 256, 8, 8)
blk 5 : (128, 256, 8, 8)
blk 6 : (128, 10)
