In [1]:
import mxnet as mx
import numpy as np

from mxnet import nd
from mxnet import gluon
from mxnet.gluon import nn

mx.random.seed(1)
ctx = mx.gpu()

## DenseNet-BC

In [2]:
def BN_ReLU_Conv(num_filters, kernel_size, strides=1, padding=0, erase_relu=False):
    '''
    num_filters : int
        the output channel of this module
    kernel_size : int or tuple
        conv kernel
    strides : int or tuple
        conv strides
    padding : int or tuple
        conv padding
    erase_relu : bool
        whethr erase the relu activation
    '''
    net = gluon.nn.HybridSequential()
    with net.name_scope():
        net.add(gluon.nn.BatchNorm(axis=1, epsilon=2e-5))
        if erase_relu:
            net.add(gluon.nn.Activation('relu'))
        net.add(gluon.nn.Conv2D(num_filters, kernel_size=kernel_size, strides=strides, 
                                padding=padding, use_bias=False))
    return net

'''
BasicBlock
'''
class BasicBlock(gluon.nn.HybridBlock):
    def __init__(self, growth_rate, strides, bottleneck=False, drop_rate=0.0, **kwargs):
        '''
        growth_rate : int
        strides : step
        bottleneck : bool
        drop_rate : float
        '''
        super().__init__(**kwargs)
        blk = self.blk = gluon.nn.HybridSequential()
        with self.name_scope():
            if bottleneck:
                blk.add(BN_ReLU_Conv(int(growth_rate * 4), kernel_size=1))
                if drop_rate > 0.0:
                    blk.add(gluon.nn.Dropout(drop_rate))
                blk.add(BN_ReLU_Conv(int(growth_rate), kernel_size=3, strides=strides, padding=1))
                if drop_rate > 0.0:
                    blk.add(gluon.nn.Dropout(drop_rate))
            else:
                blk.add(BN_ReLU_Conv(int(growth_rate), kernel_size=3, strides=1))
                if drop_rate > 0.0:
                    blk.add(gluon.nn.Dropout(drop_rate))
                    
    def hybrid_forward(self, F, X):
        out = self.blk(X)
        return out
    
'''
DenseBlock
'''
class DenseBlock(gluon.nn.HybridBlock):
    def __init__(self, num_units, growth_rate, strides=1, bottleneck=False, drop_rate=0.0, **kwargs):
        '''
        num_units : int 
            each denseblock have num_units basicblock
        growth_rate : int
            the output channel of the 
        '''
        super().__init__(**kwargs)
        net = self.net = gluon.nn.HybridSequential()
        with self.name_scope():
            for i in range(num_units):
                net.add(BasicBlock(growth_rate, strides=strides, bottleneck=bottleneck, drop_rate=drop_rate))
    
    def hybrid_forward(self, F, X):
        for i, blk in enumerate(self.net):
            out = blk(X)
            X = F.concat(X, out, dim=1)
        return X
    
'''
TransitionBlock
'''
class TransitionBlock(gluon.nn.HybridBlock):
    def __init__(self, num_filter, strides=1, drop_rate=0.0, **kwargs):
        super().__init__(**kwargs)
        trans = self.trans = gluon.nn.HybridSequential()
        with self.name_scope():
            trans.add(BN_ReLU_Conv(num_filter, kernel_size=1, strides=strides))
            if drop_rate > 0.0:
                trans.add(gluon.nn.Dropout(drop_rate))
            trans.add(gluon.nn.AvgPool2D(pool_size=2, strides=2))
    
    def hybrid_forward(self, F, X):
        out = self.trans(X)
        return out
    
class DenseNet(gluon.nn.HybridBlock):
    def __init__(self, depth, num_stages, growth_rate, num_classes, data_type, reduction=0.5, 
                 bottleneck=False, drop_rate=0.0, debug=False, **kwargs):
        super().__init__(**kwargs)
        self.debug = debug
        
        if (depth - 4) % 3:
            raise Exception("Invalid Error!")
        
        if data_type == 'cifar':
            units = [(depth - 4) // 6 if bottleneck else 3] * num_stages
        else:
            raise ValueError('Not Implementation {} yet.'.format(data_type))
            
        num_units = len(units)
        assert(num_units == num_stages)
        n_channels = init_channels = 16 if data_type=="cifar" else 2 * growth_rate
        densenet = self.densenet = gluon.nn.HybridSequential()
        with self.name_scope():
            blk1 = gluon.nn.HybridSequential()
            blk1.add(gluon.nn.BatchNorm(axis=1, epsilon=2e-5))
            if data_type == 'cifar':
                blk1.add(gluon.nn.Conv2D(init_channels, kernel_size=3, padding=1, use_bias=False))
            elif data_type == 'imagenet':
                blk1.add(gluon.nn.Conv2D(init_channels, kernel_size=7, strides=2, padding=3, use_bias=False))
                blk1.add(gluon.nn.BatchNorm(axi=1, epsilon=2e-5))
                blk1.add(gluon.nn.Activation('relu'))
                blk1.add(gluon.nn.MaxPool2D(pool_size=3, strides=2, padding=1))
            densenet.add(blk1)
    
            for i in range(num_stages-1):
                blk2 = gluon.nn.HybridSequential()
                blk2.add(DenseBlock(units[i], growth_rate, bottleneck=bottleneck, drop_rate=drop_rate))
                # count channels to prepare for the transition block
                n_channels += units[i] * growth_rate
                n_channels = int(np.floor(n_channels * reduction))
                blk2.add(TransitionBlock(n_channels, strides=1, drop_rate=drop_rate))
                densenet.add(blk2)
            densenet.add(DenseBlock(units[num_stages-1], growth_rate, bottleneck=bottleneck, drop_rate=drop_rate))
            
            blk3 = gluon.nn.HybridSequential()
            blk3.add(
                gluon.nn.BatchNorm(axis=1, epsilon=2e-5),
                gluon.nn.Activation('relu')
            )
            densenet.add(blk3)
            
            blk4 = gluon.nn.HybridSequential()
            blk4.add(
                gluon.nn.GlobalAvgPool2D(),
                gluon.nn.Dense(num_classes)
            )
            densenet.add(blk4)
            
    def hybrid_forward(self, F, X):
        out = X
        for i, blk in enumerate(self.densenet):
            out = blk(out)
            if self.debug:
                print("blk {} : {}".format(i+1, out.shape))
        return out

In [26]:
symbol = DenseNet(depth=100, num_stages=3, growth_rate=12, num_classes=10, data_type='cifar', 
                  bottleneck=True, drop_rate=0.2, debug=True)
symbol.collect_params().initialize()

In [27]:
sample = nd.random.normal(shape=(1,3,32,32))
y = symbol(sample)

blk 1 : (1, 16, 32, 32)
blk 2 : (1, 104, 32, 32)
blk 3 : (1, 148, 32, 32)
blk 4 : (1, 340, 32, 32)
blk 5 : (1, 340, 32, 32)
blk 6 : (1, 10)


In [28]:
symbol

DenseNet(
  (densenet): HybridSequential(
    (0): HybridSequential(
      (0): BatchNorm(momentum=0.9, eps=2e-05, axis=1, fix_gamma=False, in_channels=3)
      (1): Conv2D(3 -> 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    )
    (1): HybridSequential(
      (0): DenseBlock(
        (net): HybridSequential(
          (0): BasicBlock(
            (blk): HybridSequential(
              (0): HybridSequential(
                (0): BatchNorm(momentum=0.9, eps=2e-05, axis=1, fix_gamma=False, in_channels=16)
                (1): Conv2D(16 -> 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
              )
              (1): Dropout(p = 0.2)
              (2): HybridSequential(
                (0): BatchNorm(momentum=0.9, eps=2e-05, axis=1, fix_gamma=False, in_channels=48)
                (1): Conv2D(48 -> 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              )
              (3): Dropout(p = 0.2)
            )
          )
          (1): Bas