In [22]:
import torch
import torch.nn as nn

In [23]:
# DenseNet bottleneck layer (DenseBlockLayer):

# batchnorm
# ReLU
# convolution 1x1

# batchnorm
# ReLU
# convolution 3x3

class DenseBlockLayer(nn.Module):
    """One bottleneck layer of a dense block: BN-ReLU-Conv1x1, then BN-ReLU-Conv3x3.

    The 1x1 convolution expands to ``4 * out_channels`` channels (``inter_planes``)
    and the 3x3 convolution (padding 1, so spatial size is kept) projects back to
    ``out_channels``.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the 3x3 convolution.
        identity_downsample: if True, the layer input is concatenated after the
            new features along the channel dimension, so the output has
            ``out_channels + in_channels`` channels. Despite the name, no
            downsampling is performed.
    """

    def __init__(self, in_channels, out_channels, identity_downsample=False):
        super(DenseBlockLayer, self).__init__()
        self.inter_planes = 4 * out_channels

        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv1 = nn.Conv2d(in_channels, self.inter_planes, kernel_size=1, stride=1)
        self.bn2 = nn.BatchNorm2d(self.inter_planes)
        self.conv2 = nn.Conv2d(self.inter_planes, out_channels, kernel_size=3, stride=1, padding=1)
        # In-place is safe: the ReLU is only ever applied to fresh BatchNorm
        # outputs, never to the layer input itself.
        self.relu = nn.ReLU(inplace=True)
        self.identity_downsample = identity_downsample

    # in_channels -> 4*out_channels (1x1) -> out_channels (3x3)
    def forward(self, x):
        """Apply the bottleneck and optionally append the input to the new features."""
        out = self.conv1(self.relu(self.bn1(x)))
        out = self.conv2(self.relu(self.bn2(out)))

        if self.identity_downsample:
            # x is never modified above, so it can be concatenated directly:
            # no defensive clone and no reference kept on the module.
            out = torch.cat((out, x), dim=1)

        return out

In [24]:
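# Transition layer placed after each dense block: the caller passes out_channels = in_channels // 2 to halve the channel count, and the 2x2 average pooling halves the spatial resolution.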
class TransitionLayer(nn.Module):
    """Transition between dense blocks: BN -> ReLU -> 1x1 conv -> 2x2 average pool.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels after the 1x1 convolution (the caller decides
            how much to compress).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.bn = nn.BatchNorm2d(in_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1)
        # kernel 2 / stride 2 halves height and width (flooring odd sizes)
        self.avg_pool = nn.AvgPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        """Normalize, activate, project channels, then downsample spatially."""
        activated = self.relu(self.bn(x))
        projected = self.conv(activated)
        return self.avg_pool(projected)

In [25]:
# The main difference from ResNet: instead of adding a shortcut, each layer concatenates its input (all earlier feature maps of the block) with its own output along the channel dimension.
class DenseNet(nn.Module):
    """DenseNet-style classifier: stem -> four dense blocks -> pooling -> linear head.

    Each dense block is built by ``make_block`` and ends with a ``TransitionLayer``
    that halves the channel count and the spatial resolution.

    Args:
        block: layer class used inside the dense blocks (e.g. ``DenseBlockLayer``).
            It is called as ``block(in_channels, in_channels)`` for the first layer
            of a block and as ``block(in_channels, growth_rate,
            identity_downsample=True)`` for every concatenating layer.
        layers: number of layers in each of the four dense blocks (only the
            first four entries are used).
        image_channels: number of channels of the input images.
        num_classes: number of output logits.
        growth_rate: channels each concatenating layer adds to the running
            feature map. Defaults to 32, the previously hard-coded value.
    """

    def __init__(self, block, layers, image_channels, num_classes, growth_rate=32):
        super(DenseNet, self).__init__()
        self.growth_rate = growth_rate
        self.out_channels = 32
        # Stem like in ResNet: a 7x7 stride-2 convolution, then max pooling.
        # NOTE(review): torchvision's DenseNet pads this conv (padding=3) and the
        # max-pool (padding=1); without padding a 224x224 input reaches the head
        # as a 3x3 map. Kept unchanged to preserve current behaviour.
        self.conv1 = nn.Conv2d(image_channels, self.out_channels, kernel_size=7, stride=2)
        self.bn1 = nn.BatchNorm2d(self.out_channels)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2)

        # Dense blocks. Each block's input width is the previous block's
        # transition output, read from the built module instead of being
        # hard-coded, so any ``layers`` / ``growth_rate`` combination works.
        # For [6, 12, 24, 16] with growth 32: 32 -> 96 -> 224 -> 480 -> 480.
        channels = self.out_channels
        stages = []
        for num_layers in layers[:4]:
            stage = self.make_block(block, num_layers, in_channels=channels)
            stages.append(stage)
            channels = stage[-1].conv.out_channels
        self.layer1, self.layer2, self.layer3, self.layer4 = stages

        # Average pooling before the fully connected layer. Despite its name it
        # pools to a fixed 7x7 grid rather than 1x1 (for 224x224 inputs the
        # incoming 3x3 map is therefore upsampled).
        self.glob_avg_pool = nn.AdaptiveAvgPool2d((7, 7))
        self.fc = nn.Linear(channels * 7 * 7, num_classes)

    # Layer creation function
    def make_block(self, block, num_residual_blocks, in_channels):
        """Build one dense block followed by its transition layer.

        Args:
            block: layer class to instantiate.
            num_residual_blocks: number of layers in the block; values below 1
                still produce the single non-concatenating first layer.
            in_channels: channels of the block input.

        Returns:
            ``nn.Sequential`` whose output has
            ``(in_channels + growth_rate * max(num_residual_blocks - 1, 0)) // 2``
            channels.
        """
        # The 1st layer of each block keeps the width and does not concatenate.
        layers = [block(in_channels, in_channels)]

        # Every further layer appends ``growth_rate`` new channels to its input.
        for _ in range(num_residual_blocks - 1):
            layers.append(block(in_channels, self.growth_rate, identity_downsample=True))
            in_channels += self.growth_rate

        layers.append(TransitionLayer(in_channels, in_channels // 2))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return logits of shape (batch, num_classes) for an image batch ``x``."""
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.glob_avg_pool(x)
        x = torch.flatten(x, 1)
        return self.fc(x)

    

In [26]:
# DenseNet-121
def DenseNet121(img_channels=3, num_classes=1000):
    """Build the DenseNet-121 layout: dense blocks of 6, 12, 24 and 16 layers.

    Args:
        img_channels: number of channels of the input images (3 for RGB).
        num_classes: number of output logits.

    Returns:
        A ``DenseNet`` built from ``DenseBlockLayer`` layers.
    """
    block_sizes = [6, 12, 24, 16]
    return DenseNet(DenseBlockLayer, block_sizes, img_channels, num_classes)

In [27]:
# Test
def test():
    """Smoke-test DenseNet121: a forward pass on a random batch yields one logit row per image."""
    net = DenseNet121()
    x = torch.rand(2, 3, 224, 224)
    # Only the output shape is checked, so skip building an autograd graph
    # for the whole network.
    with torch.no_grad():
        y = net(x)
    print(y.shape)
    assert y.shape == (2, 1000), f"unexpected output shape {tuple(y.shape)}"

In [28]:
# Run the smoke test: builds DenseNet121, feeds a random (2, 3, 224, 224) batch and prints the output shape.
test()

torch.Size([2, 1000])
