# 稠密连接网络 DenseNet
稠密连接网络是残差网络的扩展

主要由两个部分构成：稠密块和过渡层

In [5]:
import torch as t
import torch.nn as nn

def conv_block(input_channels,num_channels)->nn.Sequential:
    return nn.Sequential(
        nn.BatchNorm2d(input_channels),
        nn.ReLU(),
        nn.Conv2d(input_channels,num_channels,kernel_size=3,padding=1)
    )

class DenseBlock(nn.Module):
    def __init__(self,num_convs,input_channels,num_channels):
        super().__init__()
        layer=[]
        for i in range(num_convs):
            layer.append(conv_block(num_channels*i+input_channels,num_channels))
        self.net = nn.Sequential(*layer)
    
    def forward(self,x:t.Tensor)->t.Tensor:
        for blk in self.net:
            Y=blk(x)
            x=t.cat((x,Y),dim=1)
        return x

In [6]:
blk = DenseBlock(2, 3, 10)
X = t.randn(4, 3, 8, 8)
Y = blk(X)
Y.shape

torch.Size([4, 23, 8, 8])

In [7]:
def transition_block(input_channels,num_channels):
    return nn.Sequential(
        nn.BatchNorm2d(input_channels),
        nn.ReLU(),
        nn.Conv2d(input_channels,num_channels,kernel_size=1),
        nn.AvgPool2d(kernel_size=2,stride=2)
    )


blk = transition_block(23, 10)
blk(Y).shape


torch.Size([4, 10, 4, 4])

In [8]:
b1 = nn.Sequential(
    nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),
    nn.BatchNorm2d(64), nn.ReLU(),
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

num_channels,growth_rate = 64,32
num_convs_in_dense_blocks =[4,4,4,4]
blks=[]
for i ,num_convs in enumerate(num_convs_in_dense_blocks):
    blks.append(DenseBlock(num_convs,num_channels,growth_rate))
    num_channels+=num_convs*growth_rate

    if i!=len(num_convs_in_dense_blocks):
        blks.append(transition_block(num_channels,num_channels//2))
        num_channels=num_channels//2
net = nn.Sequential(
    b1, *blks,
    nn.BatchNorm2d(num_channels), nn.ReLU(),
    nn.AdaptiveMaxPool2d((1, 1)),
    nn.Flatten(),
    nn.Linear(num_channels, 10))

In [9]:
X = t.rand(size=(1, 1, 224, 224))
for layer in net:
    X = layer(X)
    print(layer.__class__.__name__,'output shape:\t', X.shape)

Sequential output shape:	 torch.Size([1, 64, 56, 56])
DenseBlock output shape:	 torch.Size([1, 192, 56, 56])
Sequential output shape:	 torch.Size([1, 96, 28, 28])
DenseBlock output shape:	 torch.Size([1, 224, 28, 28])
Sequential output shape:	 torch.Size([1, 112, 14, 14])
DenseBlock output shape:	 torch.Size([1, 240, 14, 14])
Sequential output shape:	 torch.Size([1, 120, 7, 7])
DenseBlock output shape:	 torch.Size([1, 248, 7, 7])
Sequential output shape:	 torch.Size([1, 124, 3, 3])
BatchNorm2d output shape:	 torch.Size([1, 124, 3, 3])
ReLU output shape:	 torch.Size([1, 124, 3, 3])
AdaptiveMaxPool2d output shape:	 torch.Size([1, 124, 1, 1])
Flatten output shape:	 torch.Size([1, 124])
Linear output shape:	 torch.Size([1, 10])
