In [13]:
import torch
import torch.nn as nn

# separable convolution: https://github.com/Cadene/pretrained-models.pytorch/blob/master/pretrainedmodels/models/xception.py#L224
# dense block: https://github.com/jacobkimmel/fcdensenet_pytorch/blob/master/model.py

"""
Each unit automatically concatenates its input with its output.
"""
class dense_block_unit(nn.Module):
    """Depthwise-separable conv unit that appends its new features to its input.

    A 3x3 depthwise convolution (``groups == in_channel``) is followed by a
    1x1 pointwise convolution, batch normalisation and ReLU. ``forward``
    returns the unit's input concatenated with that result along dim 1, so
    the returned tensor has ``in_channel + out_channel`` channels.

    Args:
        in_channel: number of channels of the incoming tensor.
        out_channel: number of channels produced by the pointwise convolution.
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()
        # Depthwise step: one 3x3 filter per input channel; kernel 3 with
        # padding 1 and the default stride of 1 keeps the spatial size.
        self.conv0 = nn.Conv2d(in_channel, in_channel, kernel_size=3, padding=1,
                               groups=in_channel, bias=False)
        # Pointwise step: 1x1 convolution mapping in_channel -> out_channel.
        self.conv1 = nn.Conv2d(in_channel, out_channel, kernel_size=1, bias=False)
        self.bn = nn.BatchNorm2d(out_channel)
        self.act = nn.ReLU(inplace=True)

    def forward(self, inputs):
        """Return ``inputs`` with the freshly computed features appended on dim 1."""
        depthwise = self.conv0(inputs)
        pointwise = self.conv1(depthwise)
        new_features = self.act(self.bn(pointwise))
        # The concatenation happens inside the unit itself.
        return torch.cat((inputs, new_features), dim=1)

class dense_block(nn.Module):
    """Sequential chain of ``num_units`` dense units.

    Unit ``i`` is constructed as
    ``dense_block_unit(in_channel + i * growth_rate, growth_rate)`` and the
    units are wrapped in an ``nn.Sequential``. The input width of each unit
    grows by ``growth_rate`` per position, which matches units that return
    their input concatenated with ``growth_rate`` new channels.

    Args:
        num_units: how many units to chain.
        in_channel: number of channels of the block input.
        growth_rate: number of channels requested from every unit.
    """

    def __init__(self, num_units, in_channel, growth_rate):
        super().__init__()
        self.in_channel = in_channel
        self.growth_rate = growth_rate

        # Unit idx receives the block input plus idx * growth_rate extra channels.
        units = [
            dense_block_unit(in_channel + idx * growth_rate, growth_rate)
            for idx in range(num_units)
        ]
        self.layers = nn.Sequential(*units)

    def forward(self, inputs):
        """Feed ``inputs`` through every unit in order and return the final tensor."""
        return self.layers(inputs)
            
        

In [17]:
# Four units on a 32-channel input, with a growth rate of 16.
tiny_model = dense_block(num_units=4, in_channel=32, growth_rate=16)

In [16]:
# Push a random batch through the block and inspect the resulting shape.
inputs = torch.randn(3, 32, 64, 64)
output = tiny_model(inputs)
print(output.shape)               # output channels = in_channel + num_units * growth_rate
                                  # i.e. 32 + 4*16 = 96
# Pin the channel count described above so a regression fails loudly.
assert output.shape[1] == 32 + 4 * 16

torch.Size([3, 96, 64, 64])


In [20]:
"""
Build a module list; during the forward pass, manually concatenate the current layer's result so it becomes the next layer's input.
"""

class dense_block_unit(nn.Module):
    """Depthwise-separable conv unit that returns only its new features.

    A 3x3 depthwise convolution (``groups == in_channel``) is followed by a
    1x1 pointwise convolution, batch normalisation and ReLU. Unlike a unit
    that concatenates internally, ``forward`` returns just the
    ``out_channel`` new channels; any concatenation is left to the caller.

    Args:
        in_channel: number of channels of the incoming tensor.
        out_channel: number of channels produced by the pointwise convolution.
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()
        # Depthwise step: one 3x3 filter per input channel; kernel 3 with
        # padding 1 and the default stride of 1 keeps the spatial size.
        self.conv0 = nn.Conv2d(in_channel, in_channel, kernel_size=3, padding=1,
                               groups=in_channel, bias=False)
        # Pointwise step: 1x1 convolution mapping in_channel -> out_channel.
        self.conv1 = nn.Conv2d(in_channel, out_channel, kernel_size=1, bias=False)
        self.bn = nn.BatchNorm2d(out_channel)
        self.act = nn.ReLU(inplace=True)

    def forward(self, inputs):
        """Return the ``out_channel`` feature maps computed from ``inputs``."""
        depthwise = self.conv0(inputs)
        pointwise = self.conv1(depthwise)
        return self.act(self.bn(pointwise))

    
class dense_block(nn.Module):
    """Dense block that concatenates unit outputs manually in ``forward``.

    Unit ``i`` is constructed as
    ``dense_block_unit(in_channel + i * growth_rate, growth_rate)`` and stored
    in an ``nn.ModuleList``. During the forward pass the output of every unit
    is concatenated onto the running feature tensor along dim 1, and that
    grown tensor is what the next unit receives.

    Args:
        num_units: how many units the block contains.
        in_channel: number of channels of the block input.
        growth_rate: number of channels requested from every unit.
    """

    def __init__(self, num_units, in_channel, growth_rate):
        super().__init__()
        self.in_channel = in_channel
        self.growth_rate = growth_rate

        # Unit idx receives the block input plus idx * growth_rate extra channels.
        self.layers = nn.ModuleList([
            dense_block_unit(in_channel + idx * growth_rate, growth_rate)
            for idx in range(num_units)
        ])

    def forward(self, inputs):
        """Return ``inputs`` followed by the output of every unit, joined on dim 1."""
        features = inputs
        for unit in self.layers:
            # Manual concatenation: append this unit's output before the next unit runs.
            features = torch.cat((features, unit(features)), dim=1)
        return features

In [21]:
# Build the block (4 units, 32 input channels, growth rate 16) and run a random batch through it.
tiny_model = dense_block(num_units=4, in_channel=32, growth_rate=16)
inputs = torch.randn(3, 32, 64, 64)
output = tiny_model(inputs)
print(output.shape)               # output channels = in_channel + num_units * growth_rate
                                  # i.e. 32 + 4*16 = 96
# Pin the channel count described above so a regression fails loudly.
assert output.shape[1] == 32 + 4 * 16

torch.Size([3, 96, 64, 64])
