In [1]:
import torch
from torch import nn
from torchinfo import summary

In [19]:
class DenseBlock(nn.Module):
    """One DenseNet-BC bottleneck layer whose output is concatenated onto its input.

    The new features are produced by BN -> ReLU -> 1x1 conv (``bn_size * k``
    channels) -> BN -> ReLU -> 3x3 conv (``k`` channels). Both convolutions use
    stride 1 (and the 3x3 uses padding 1), so spatial size is preserved and the
    output has ``in_channel + k`` channels.

    Args:
        in_channel: Number of channels of the input tensor.
        k: Growth rate, i.e. how many new feature maps this layer appends.
        bn_size: Width multiplier of the 1x1 bottleneck convolution; the
            bottleneck has ``bn_size * k`` channels. Defaults to 4, the value
            this layer previously hard-coded.
    """

    def __init__(self, in_channel, k, bn_size=4):
        super().__init__()

        bottleneck_channel = bn_size * k

        # Named ``residual`` for state_dict compatibility; note that it is a
        # concatenated (dense) path, not a residual sum.
        self.residual = nn.Sequential(
            nn.BatchNorm2d(in_channel),
            nn.ReLU(),
            nn.Conv2d(in_channel, bottleneck_channel, 1, bias=False),

            nn.BatchNorm2d(bottleneck_channel),
            nn.ReLU(),
            nn.Conv2d(bottleneck_channel, k, 3, padding=1, bias=False)
        )

    def forward(self, x):
        """Return ``cat([x, new_features], dim=1)``.

        Assumes ``x`` is (N, in_channel, H, W); the result is (N, in_channel + k, H, W).
        """
        return torch.cat([x, self.residual(x)], dim=1)

# in_channel = 32
# model = DenseBlock(in_channel, 4)
# summary(model, (2, in_channel, 56,56))

In [20]:
class TransitionBlock(nn.Module):
    """Transition layer placed between dense stages.

    Applies BN -> ReLU -> 1x1 conv (``in_channel`` -> ``out_channel``) and then
    a 2x2 average pool. The pool's stride defaults to its kernel size, so each
    spatial dimension is halved (rounded down).

    Args:
        in_channel: Number of channels of the incoming feature maps.
        out_channel: Number of channels produced by the 1x1 convolution.
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()

        layers = [
            nn.BatchNorm2d(in_channel),
            nn.ReLU(),
            nn.Conv2d(in_channel, out_channel, 1, bias=False),
            nn.AvgPool2d(kernel_size=2),
        ]
        # The misspelled attribute name is kept on purpose: it is the prefix of
        # every state_dict key ("trainsition.0.weight", ...), so renaming it
        # would break loading previously saved weights.
        self.trainsition = nn.Sequential(*layers)

    def forward(self, x):
        """Compress channels and downsample ``x`` (assumed NCHW)."""
        return self.trainsition(x)

In [35]:
class DenseNet(nn.Module):
    """DenseNet image classifier with exactly four dense stages.

    Layout: 7x7 stride-2 conv stem (``2 * growth_rate`` channels) -> 3x3
    stride-2 max-pool -> [dense stage -> transition] x 3 -> dense stage ->
    BN -> ReLU -> global average pool -> linear classifier.

    Args:
        block_num_list: Number of ``DenseBlock`` layers in each of the four
            dense stages, e.g. ``[6, 12, 24, 16]`` for DenseNet-121.
        growth_rate: Channels added by every ``DenseBlock``; the stem produces
            ``2 * growth_rate`` channels.
        num_classes: Number of output logits of the final linear layer.
        compression: Fraction of channels each transition layer keeps
            (rounded down). Defaults to 0.5, the value previously hard-coded.

    Raises:
        ValueError: If ``block_num_list`` does not have exactly four entries,
            or ``compression`` is not in ``(0, 1]``.
    """

    def __init__(self, block_num_list, growth_rate, num_classes, compression=0.5):
        super().__init__()

        block_num_list = list(block_num_list)
        # Previously extra entries were silently ignored and too few raised an
        # opaque IndexError; the architecture below has exactly four stages.
        if len(block_num_list) != 4:
            raise ValueError(
                "block_num_list must have exactly 4 entries (one per dense "
                f"stage), got {len(block_num_list)}"
            )
        if not 0 < compression <= 1:
            raise ValueError(f"compression must be in (0, 1], got {compression}")

        self.k = growth_rate
        self.compression = compression
        # Running channel count: each make_* call below reads it and then
        # advances it, so the order of these calls defines the network.
        self.in_channel = 2*growth_rate

        self.conv1 = nn.Sequential(
            nn.Conv2d(3, self.in_channel, 7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(self.in_channel),
            nn.ReLU()
        )
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.dense_block_1 = self.make_dense_block(block_num_list[0])
        self.transition_block_1 = self.make_transition_block()

        self.dense_block_2 = self.make_dense_block(block_num_list[1])
        self.transition_block_2 = self.make_transition_block()

        self.dense_block_3 = self.make_dense_block(block_num_list[2])
        self.transition_block_3 = self.make_transition_block()

        # The last stage has no transition; its output is normalized directly.
        self.dense_block_4 = self.make_dense_block(block_num_list[3])
        self.bn = nn.BatchNorm2d(self.in_channel)
        self.relu = nn.ReLU()

        self.GlobalAvgPool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(self.in_channel, num_classes)

    def forward(self, x):
        """Return class logits of shape (N, num_classes).

        Assumes ``x`` is (N, 3, H, W); the stem convolution expects 3 input channels.
        """
        x = self.conv1(x)
        x = self.maxpool(x)

        x = self.dense_block_1(x)
        x = self.transition_block_1(x)

        x = self.dense_block_2(x)
        x = self.transition_block_2(x)

        x = self.dense_block_3(x)
        x = self.transition_block_3(x)

        x = self.dense_block_4(x)
        x = self.bn(x)
        x = self.relu(x)

        x = self.GlobalAvgPool(x)
        x = torch.flatten(x, start_dim=1)
        x = self.fc(x)

        return x

    def make_dense_block(self, block_num):
        """Stack ``block_num`` DenseBlocks, growing ``self.in_channel`` by ``k`` each."""
        dense_block = []

        for _ in range(block_num):
            dense_block += [DenseBlock(self.in_channel, k=self.k)]
            self.in_channel += self.k

        return nn.Sequential(*dense_block)

    def make_transition_block(self, compression=None):
        """Build a TransitionBlock that shrinks ``self.in_channel`` and update it.

        Args:
            compression: Fraction of channels to keep (rounded down). ``None``
                uses ``self.compression`` as set in ``__init__``.
        """
        if compression is None:
            compression = self.compression
        reduced_channel = int(self.in_channel*compression)
        transition_block = TransitionBlock(self.in_channel, reduced_channel)
        self.in_channel = reduced_channel
        return transition_block

# model = DenseNet(block_num_list=[6, 12, 64, 48], growth_rate=32, num_classes=1000)
# summary(model, (2, 3, 224, 224))

In [37]:
def _densenet(block_num_list, **kwargs):
    """Build a DenseNet with the given stage depths, forwarding ``kwargs``.

    ``growth_rate`` defaults to 32 and ``num_classes`` to 1000 unless the
    caller overrides them; any other keyword is passed to ``DenseNet`` as-is.
    """
    kwargs.setdefault("growth_rate", 32)
    kwargs.setdefault("num_classes", 1000)
    return DenseNet(block_num_list=block_num_list, **kwargs)

def DenseNet121(**kwargs):
    """DenseNet-121 (stage depths 6, 12, 24, 16); kwargs go to ``DenseNet``."""
    return _densenet([6,12,24,16], **kwargs)

def DenseNet169(**kwargs):
    """DenseNet-169 (stage depths 6, 12, 32, 32); kwargs go to ``DenseNet``."""
    return _densenet([6,12,32,32], **kwargs)

def DenseNet201(**kwargs):
    """DenseNet-201 (stage depths 6, 12, 48, 32); kwargs go to ``DenseNet``."""
    return _densenet([6,12,48,32], **kwargs)

def DenseNet264(**kwargs):
    """DenseNet-264 (stage depths 6, 12, 64, 48); kwargs go to ``DenseNet``."""
    return _densenet([6,12,64,48], **kwargs)

In [38]:
# Build the deepest variant and show a layer-by-layer summary for a batch of
# two 3-channel 224x224 inputs (the call is the cell's displayed output).
model = DenseNet264()
summary(model, input_size=(2, 3, 224, 224))

Layer (type:depth-idx)                   Output Shape              Param #
DenseNet                                 [2, 1000]                 --
├─Sequential: 1-1                        [2, 64, 112, 112]         --
│    └─Conv2d: 2-1                       [2, 64, 112, 112]         9,408
│    └─BatchNorm2d: 2-2                  [2, 64, 112, 112]         128
│    └─ReLU: 2-3                         [2, 64, 112, 112]         --
├─MaxPool2d: 1-2                         [2, 64, 56, 56]           --
├─Sequential: 1-3                        [2, 256, 56, 56]          --
│    └─DenseBlock: 2-4                   [2, 96, 56, 56]           --
│    │    └─Sequential: 3-1              [2, 32, 56, 56]           45,440
│    └─DenseBlock: 2-5                   [2, 128, 56, 56]          --
│    │    └─Sequential: 3-2              [2, 32, 56, 56]           49,600
│    └─DenseBlock: 2-6                   [2, 160, 56, 56]          --
│    │    └─Sequential: 3-3              [2, 32, 56, 56]           53,760