**DenseNet: Densely Connected Convolutional Networks**   
*Gao Huang, Zhuang Liu, Laurens van der Maaten, Kilian Q. Weinberger*  
[[paper](https://arxiv.org/abs/1608.06993)]    
CVPR 2017



In [3]:
# for ImageNet 
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np

In [4]:
class ConvLayer(nn.Module):
    """Pre-activation composite function H(.) = BatchNorm -> ReLU -> Conv2d.

    Args:
        in_dim: number of channels of the incoming feature map.
        out_dim: number of channels produced by the convolution.
        convKernel_size: side length of the square convolution kernel.
        convStride: stride of the convolution.
    """

    def __init__(self, in_dim, out_dim,
                 convKernel_size=3, convStride=2) -> None:
        super().__init__()

        # Half-kernel padding: the spatial size is preserved only when the
        # kernel is odd and the stride is 1; with a larger stride the output
        # is downsampled accordingly.
        half_kernel = convKernel_size // 2

        self.norm = nn.BatchNorm2d(in_dim)
        self.act = nn.ReLU()
        self.conv = nn.Conv2d(
            in_dim,
            out_dim,
            kernel_size=convKernel_size,
            stride=convStride,
            padding=half_kernel,
        )

    def forward(self, inputs):
        """Normalize, activate, then convolve ``inputs``."""
        return self.conv(self.act(self.norm(inputs)))

In [5]:
class InitialConv(nn.Module):
    """Network stem: one BN-ReLU-Conv composite followed by max pooling.

    Args:
        in_dim: number of channels of the input image.
        out_dim: number of channels produced by the stem convolution.
        convKernel_size: kernel size of the stem convolution.
        convStride: stride of the stem convolution.
        poolingKernel_size: kernel size of the max-pooling window.
        poolingStride: stride of the max pooling.
    """

    def __init__(self, in_dim, out_dim,
                    convKernel_size=7, convStride=2, poolingKernel_size=3, poolingStride=2) -> None:
        super(InitialConv, self).__init__()

        self.composite = ConvLayer(in_dim, out_dim, convKernel_size, convStride)
        # No padding is applied in the pooling, so border rows/columns that do
        # not fill a complete window are dropped.
        self.pool = nn.MaxPool2d(poolingKernel_size, stride=poolingStride)

    def forward(self, inputs):
        """Return the pooled stem features for ``inputs``."""
        h = self.composite(inputs)
        # The pooling module must be *called* on the features; assigning the
        # module itself would hand an nn.MaxPool2d object to the next layer.
        h = self.pool(h)

        return h

In [6]:
class DenseBlock(nn.Module):
    """Stack of densely connected (1x1 conv, kxk conv) ConvLayer pairs.

    Pair ``idx`` (0-based) receives the block input concatenated, along the
    channel axis, with the outputs of pairs ``0 .. idx-1``; it therefore sees
    ``in_dim + idx * out_dim`` channels and emits ``out_dim`` new channels.

    Note:
        ``forward`` returns only the output of the last pair (``out_dim``
        channels), not the concatenation of every feature map.

    Args:
        in_dim: number of channels of the block input.
        out_dim: number of channels produced by every pair.
        num_conv: number of (1x1, kxk) pairs; must be at least 1.
        convKernel_size: kernel size of the second convolution of each pair.
            It should be odd so that the half-kernel padding used by
            ConvLayer keeps the spatial size unchanged.
        convStride: accepted for backward compatibility only. Channel-wise
            concatenation requires every pair to preserve the spatial size,
            so the convolutions always run with stride 1.

    Raises:
        ValueError: if ``num_conv`` is smaller than 1.
    """

    def __init__(self, in_dim, out_dim, num_conv=6,
                    convKernel_size=3, convStride=2) -> None:
        super(DenseBlock, self).__init__()

        if num_conv < 1:
            raise ValueError(f"num_conv must be >= 1, got {num_conv}")

        self.num_conv = num_conv
        self.consecutive = nn.ModuleList()

        for idx in range(self.num_conv):
            # 1x1 convolution: its input is the block input plus the outputs
            # of the `idx` pairs that precede it.
            self.consecutive.append(ConvLayer(in_dim + out_dim * idx, out_dim, convKernel_size=1, convStride=1))
            # kxk convolution: stride is pinned to 1 so its output can be
            # concatenated with the features it was computed from.
            self.consecutive.append(ConvLayer(out_dim, out_dim, convKernel_size=convKernel_size, convStride=1))

    def forward(self, inputs):
        """Run every pair on the growing feature stack; return the last output."""
        prev = inputs

        for idx in range(self.num_conv):
            inter = self.consecutive[idx * 2](prev)        # 1x1 conv
            inter = self.consecutive[idx * 2 + 1](inter)   # kxk conv
            # Grow the feature stack along the channel axis for the next pair.
            prev = torch.cat([prev, inter], dim=1)

        return inter  # the final layer output
            

In [7]:
class TransitionLayer(nn.Module):
    """Transition between dense blocks: 1x1 ConvLayer, then average pooling.

    Args:
        in_dim: number of channels of the incoming feature map.
        out_dim: number of channels after the 1x1 convolution.
        poolingKernel_size: kernel size of the average-pooling window.
        poolingStride: stride of the average pooling.
    """

    def __init__(self, in_dim, out_dim, 
                 poolingKernel_size=2, poolingStride=2) -> None:
        super().__init__()

        # Channel projection with a stride-1 1x1 convolution.
        self.conv = ConvLayer(in_dim, out_dim, convKernel_size=1, convStride=1)
        # Local (windowed) average pooling, not global pooling.
        self.pool = nn.AvgPool2d(poolingKernel_size, poolingStride)

    def forward(self, dense_out):
        """Project the channels of ``dense_out`` and downsample the result."""
        projected = self.conv(dense_out)
        return self.pool(projected)

In [None]:
class DenseNet(nn.Module):
    """DenseNet-style classifier: stem, four dense blocks, three transitions.

    Channel plan (k = ``growthRate``): the stem emits 2k channels; each dense
    block keeps the channel count of its input on its output, and each
    transition doubles it, so the classifier receives 16k features.

    Args:
        in_dim: number of channels of the input image.
        out_dim: number of output classes (size of the returned logits).
        growthRate: base channel multiplier k.
    """

    def __init__(self, in_dim, out_dim, growthRate=32) -> None:
        super(DenseNet, self).__init__()

        self.growthRate = growthRate

        # The stem must emit growthRate*2 channels, the width dense1 expects.
        self.init_conv = InitialConv(in_dim, self.growthRate*2, convKernel_size=7, convStride=2, poolingKernel_size=3, poolingStride=2)

        # Dense blocks run with stride 1: their layers are concatenated along
        # the channel axis, which requires an unchanged spatial size.
        # Downsampling is done by the transition layers instead.
        self.dense1      = DenseBlock(self.growthRate*2, self.growthRate*2, num_conv=6, convKernel_size=3, convStride=1)
        self.transition1 = TransitionLayer(self.growthRate*2, self.growthRate*4)

        self.dense2      = DenseBlock(self.growthRate*4, self.growthRate*4, num_conv=12, convKernel_size=3, convStride=1)
        self.transition2 = TransitionLayer(self.growthRate*4, self.growthRate*8)

        self.dense3      = DenseBlock(self.growthRate*8, self.growthRate*8, num_conv=12, convKernel_size=3, convStride=1)
        self.transition3 = TransitionLayer(self.growthRate*8, self.growthRate*16)

        self.dense4      = DenseBlock(self.growthRate*16, self.growthRate*16, num_conv=12, convKernel_size=3, convStride=1)
        self.pool        = nn.AdaptiveAvgPool2d(1)
        self.classifier  = nn.Sequential(
                        nn.Linear(self.growthRate*16, self.growthRate*16),
                        nn.Dropout(),
                        nn.ReLU(),
                        nn.Linear(self.growthRate*16, out_dim)
                    )

    def forward(self, inputs):
        """Return class logits of shape (batch, out_dim) for ``inputs``."""
        h0  = self.init_conv(inputs)

        h1  = self.dense1(h0)
        h1  = self.transition1(h1)

        h2  = self.dense2(h1)
        h2  = self.transition2(h2)

        h3  = self.dense3(h2)
        h3  = self.transition3(h3)

        h4  = self.dense4(h3)

        hg  = self.pool(h4)
        # Global pooling leaves a (batch, channels, 1, 1) tensor; collapse the
        # trailing singleton dims so nn.Linear sees (batch, channels).
        hg  = torch.flatten(hg, start_dim=1)

        out = self.classifier(hg)

        return out