# DenseNet的PyTorch实现

In [1]:
import time
import torch
from torch import nn, optim
import torchvision
from torchsummary import summary
import torch.nn.functional as F


In [30]:
class DenseBlock(nn.Module):
    """Dense block: each unit's output is concatenated onto its own input.

    The block holds ``num_convs`` units built by :meth:`conv_block`. Unit
    ``k`` receives ``num_input_features + k * growth_rate`` channels and
    emits ``growth_rate`` new ones, so the block's output has
    ``num_input_features + num_convs * growth_rate`` channels. Spatial size
    is preserved (1x1 conv, then 3x3 conv with padding 1).

    Args:
        num_convs: Number of units in the block.
        growth_rate: Channels produced by every unit.
        num_input_features: Channels of the tensor fed to the block.
    """

    def __init__(self, num_convs, growth_rate, num_input_features):
        super().__init__()
        self.num_convs = num_convs
        # Every unit sees the block input plus everything produced before it.
        units = [
            self.conv_block(num_input_features + k * growth_rate, growth_rate)
            for k in range(num_convs)
        ]
        self.blocks = nn.ModuleList(units)

    def conv_block(self, in_channels, out_channels):
        """Build one BN-ReLU-Conv1x1-BN-ReLU-Conv3x3 unit.

        Both convolutions are bias-free and both emit ``out_channels``
        feature maps.
        """
        pointwise = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            bias=False,
        )
        spatial = nn.Conv2d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=3,
            padding=1,
            bias=False,
        )
        return nn.Sequential(
            nn.BatchNorm2d(num_features=in_channels),
            nn.ReLU(),
            pointwise,
            nn.BatchNorm2d(num_features=out_channels),
            nn.ReLU(),
            spatial,
        )

    def forward(self, x):
        """Run all units, growing the channel dimension (dim 1) each step."""
        features = x
        for unit in self.blocks:
            features = torch.cat((features, unit(features)), dim=1)
        return features
            

In [31]:
class DenesNetModel(nn.Module):
    """DenseNet-style classifier: stem, four dense blocks, linear head.

    Layout: a 7x7 stride-2 convolution stem with max pooling, then four
    ``DenseBlock`` stages. Each of the first three stages is followed by a
    transition layer that halves both the channel count and the spatial
    size. A 6x6 average pool and a linear layer produce the class scores.

    Args:
        growth_rate: Channels added by every unit inside a dense block.
        block_config: Sequence whose first four entries give the number of
            units in dense blocks 1-4.
        num_init_features: Output channels of the stem convolution.
        classes: Number of output classes (width of the final linear layer).

    Note:
        The class name keeps its original spelling so existing callers and
        checkpoints keep working.
    """

    def __init__(self, growth_rate, block_config, num_init_features, classes):
        super(DenesNetModel, self).__init__()
        # Stem: 3-channel input -> num_init_features maps, downsampled twice.
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=num_init_features, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(num_init_features),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2)
        )

        # num_features tracks the channel count flowing between stages.
        num_features = num_init_features
        self.dense_block1 = DenseBlock(num_convs=block_config[0], growth_rate=growth_rate,
                                      num_input_features=num_features)
        num_features = num_features + growth_rate * block_config[0]
        self.trans1 = self.transition_layer(in_channels=num_features, out_channels=num_features // 2)

        num_features = num_features // 2
        self.dense_block2 = DenseBlock(num_convs=block_config[1], growth_rate=growth_rate,
                                      num_input_features=num_features)
        num_features = num_features + growth_rate * block_config[1]
        self.trans2 = self.transition_layer(in_channels=num_features, out_channels=num_features // 2)

        num_features = num_features // 2
        self.dense_block3 = DenseBlock(num_convs=block_config[2], growth_rate=growth_rate,
                                      num_input_features=num_features)
        num_features = num_features + growth_rate * block_config[2]
        self.trans3 = self.transition_layer(in_channels=num_features, out_channels=num_features // 2)

        num_features = num_features // 2
        self.dense_block4 = DenseBlock(num_convs=block_config[3], growth_rate=growth_rate,
                                      num_input_features=num_features)
        num_features = num_features + growth_rate * block_config[3]

        # The linear head expects exactly num_features inputs, so the pooled
        # map must be 1x1. A 224x224 input reaches this point as a 6x6 map
        # (112 -> 55 -> 27 -> 13 -> 6), which a kernel of 6 reduces to 1x1.
        self.avgpool = nn.AvgPool2d(kernel_size=6)
        self.fc = nn.Linear(num_features, classes)

    def transition_layer(self, in_channels, out_channels):
        """Build a BN-ReLU-Conv1x1-AvgPool transition between dense blocks.

        The 1x1 convolution maps ``in_channels`` to ``out_channels`` and the
        2x2 stride-2 average pool halves the spatial size.
        """
        transition = nn.Sequential(
            nn.BatchNorm2d(num_features=in_channels),
            nn.ReLU(),
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1),
            nn.AvgPool2d(kernel_size=2, stride=2)
        )
        return transition

    def forward(self, x):
        """Return the class scores for a batch of images.

        Args:
            x: Image batch with 3 channels in dim 1.

        Returns:
            Tensor with one row per sample and ``classes`` columns.
        """
        x = self.conv1(x)
        x = self.dense_block1(x)
        x = self.trans1(x)
        x = self.dense_block2(x)
        x = self.trans2(x)
        x = self.dense_block3(x)
        x = self.trans3(x)
        x = self.dense_block4(x)
        x = self.avgpool(x)
        # Flatten everything but the batch dimension before the linear head.
        x = self.fc(x.view(x.shape[0], -1))
        # Previously the logits were computed but never returned, so calling
        # the model yielded None.
        return x

# Configuration: growth_rate=32, block_config=[6, 12, 32, 32],
# num_init_features=64, classes=10.
net = DenesNetModel(
    growth_rate=32,
    block_config=[6, 12, 32, 32],
    num_init_features=64,
    classes=10,
)
# Print the per-layer summary for a 3x224x224 input, evaluated on the CPU.
summary(net, (3, 224, 224), device="cpu")
        
        
        
        

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 55, 55]               0
       BatchNorm2d-5           [-1, 64, 55, 55]             128
              ReLU-6           [-1, 64, 55, 55]               0
            Conv2d-7           [-1, 32, 55, 55]           2,048
       BatchNorm2d-8           [-1, 32, 55, 55]              64
              ReLU-9           [-1, 32, 55, 55]               0
           Conv2d-10           [-1, 32, 55, 55]           9,216
      BatchNorm2d-11           [-1, 96, 55, 55]             192
             ReLU-12           [-1, 96, 55, 55]               0
           Conv2d-13           [-1, 32, 55, 55]           3,072
      BatchNorm2d-14           [-1, 32, 55, 55]              64
                 ...

Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 253.12
Params size (MB): 15.19
Estimated Total Size (MB): 268.89
----------------------------------------------------------------
