In [6]:
import torch
import torch.nn as nn

# separable convolution: https://github.com/Cadene/pretrained-models.pytorch/blob/master/pretrainedmodels/models/xception.py#L224
# dense block: https://github.com/jacobkimmel/fcdensenet_pytorch/blob/master/model.py

"""
每个unit，自动cancat其输入与输出
"""
class dense_block_unit(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(dense_block_unit, self).__init__()
        self.conv0 = nn.Conv2d(in_channel, in_channel, kernel_size=3, stride=1, padding=1, dilation=1, groups=in_channel, bias=False)
        self.conv1 = nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=1, padding=0, dilation=1, bias=False)
        self.bn = nn.BatchNorm2d(out_channel)
        self.act = nn.ReLU(inplace=True)
    
    def forward(self, inputs):
        output = self.conv0(inputs)
        output = self.conv1(output)
        output = self.act(self.bn(output))
        concat = torch.cat([inputs, output], dim=1)   # 自动 cancat
        
        return concat

class dense_block(nn.Module):
    def __init__(self, num_units, in_channel, growth_rate, keep_input=True):
        super(dense_block, self).__init__()
        self.in_channel = in_channel
        self.growth_rate = growth_rate
        self.keep_input = keep_input
        
        if keep_input:
            self.out_channel = self.in_channel + num_units*growth_rate
        else:
            self.out_channel = num_units*growth_rate
            
        layers = []
        for i in range(num_units):
            layers.append(dense_block_unit(in_channel + i*growth_rate, growth_rate))
            
        self.layers = nn.Sequential(*layers)
        
    def forward(self, inputs):
        output = self.layers(inputs)
        
        if not self.keep_input:
            output = output[:, self.in_channel:, ...]
        
        return output
            
        

In [17]:
tiny_model = dense_block(4, 32, 16)

In [16]:
inputs = torch.randn(3, 32, 64, 64)
output= tiny_model(inputs)
print(output.shape)               # output channel = input channel + (num_units-1)*growth_rate
                                  # i.e. 32 + 4*16 = 96

torch.Size([3, 96, 64, 64])


In [4]:
"""
构建 module list，前向 forward 时，手动 cancat 当前层结果作为下一层输入
"""

class dense_block_unit(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(dense_block_unit, self).__init__()
        self.conv0 = nn.Conv2d(in_channel, in_channel, kernel_size=3, stride=1, padding=1, dilation=1, groups=in_channel, bias=False)
        self.conv1 = nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=1, padding=0, dilation=1, bias=False)
        self.bn = nn.BatchNorm2d(out_channel)
        self.act = nn.ReLU(inplace=True)
    
    def forward(self, inputs):
        output = self.conv0(inputs)
        output = self.conv1(output)
        output = self.act(self.bn(output))

        return output

    
class dense_block(nn.Module):
    def __init__(self, num_units, in_channel, growth_rate, keep_input=True):
        super(dense_block, self).__init__()
        self.in_channel = in_channel
        self.growth_rate = growth_rate
        self.keep_input = keep_input
        
        if keep_input:
            self.out_channel = self.in_channel + num_units*growth_rate
        else:
            self.out_channel = num_units*growth_rate
        
        # 构建 module list
        self.layers = nn.ModuleList()
        for i in range(num_units):
            self.layers.append(dense_block_unit(in_channel + i*growth_rate, growth_rate))
        
    def forward(self, inputs):
        for layers in self.layers:
            output = layers(inputs)
            inputs = torch.cat([inputs, output], dim=1)        # 前向传播时，手动cancat结果
        
        if not self.keep_input:
            inputs = inputs[:, self.in_channel:, ...]
            
        return inputs

In [21]:
tiny_model = dense_block(4, 32, 16)
inputs = torch.randn(3, 32, 64, 64)
output= tiny_model(inputs)
print(output.shape)               # output channel = input channel + (num_units-1)*growth_rate
                                  # i.e. 32 + 4*16 = 96

torch.Size([3, 96, 64, 64])


In [1]:
# 另外一种构建模型的方式， 直接继承 nn.Sequential()

# https://github.com/bfortuner/pytorch_tiramisu/blob/master/models/layers.py

class DenseLayer(nn.Sequential):
    def __init__(self, in_channels, growth_rate):
        super().__init__()
        self.add_module('norm', nn.BatchNorm2d(in_channels))
        self.add_module('relu', nn.ReLU(True))
        self.add_module('conv', nn.Conv2d(in_channels, growth_rate, kernel_size=3,
                                          stride=1, padding=1, bias=True))
        self.add_module('drop', nn.Dropout2d(0.2))

    def forward(self, x):
        return super().forward(x)

SyntaxError: invalid character in identifier (<ipython-input-1-bf96525c5634>, line 1)

In [None]:
# 堆加 dense block 的一种写法
class Densenet_tiny(nn.Module):
    def __init__(self, in_channel, first_conv_out, growth_rate, down_blocks=(4, 4), bottle_block=4,
                up_blocks=(4, 4)):
        super(Densenet_tiny, self).__init__()
        self.down_blocks = down_blocks
        self.up_blocks = up_blocks
        self.conv0 = nn.Conv2d(in_channel, first_conv_out, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(first_conv_out)
        self.act = nn.ReLU(inplace=True)
        
        encoder_channel_in = first_conv_out
        skip_connection_channel = []
        # encoding path
        for i in range(len(encoder_blocks)):
            setattr(self, 'down_block_' + str(i), dense_block(num_units=encoder_blocks[i], 
                                                              encoder_channel_in, growth_rate))
            setattr(self, 'trans_down_' + str(i), nn.MaxPool2d(2, 2))
            encoder_channel_in = encoder_channel_in + encoder_blocks[i]*growth_rate
            skip_connection_channel.append(encoder_channel_in)
            
        self.bottle_block = dense_block(num_units=bottle_block, encoder_channel_in, growth_rate)
        
        # decoding path
        decoder_channel_in = bottle_block*growth_rate
        for i in range(len(up_blocks)):
            setattr(self, 'trans_up' + str(i), nn.ConvTranspose2d(decoder_channel_in, decoder_channel_in,
                                                                 3, 2, padding=1, output_padding=1))
            decoder_channel_in = decoder_channel_in + skip_connection_channel[-(i+1)]
            setattr(self, 'up_block_' + str(i), dense_block(num_units=up_blocks[i],decoder_channel_in, 
                                                            growth_rate, keep_input=False))
        # final conv
        
    def forward(self, inputs):
        output = self.conv0(inputs)
        
        # encoding path
        self.encoding_blocaks_out = []
        for i in range(len(self.down_blocks)):
            encode_block = getattr(self, 'down_block_' + str(i))
            td_layer = getattr(self, 'trans_down_' + str(i))
            
            output = encode_block(output)
            self.encoding_blocaks_out.append(ouput)
            output = td_layer(output)
            
        output = self.bottle_block(output)
        
        for i in range(len(self.up_blocks)):
            decode_block = getattr(self, 'up_block_' + str(i))
            tu_layer = getattr(self, 'trans_up_' + str(i))
            
            output = tu_layer(output)
            
            skip_out = self.encoding_blocaks_out.pop()
            output = torch.cat([skip_out, output], dim=1)
            
            output = decode_block(output)
            
        output = self.final_conv()
            