**DenseNet: Densely Connected Convolutional Networks**   
*Gao Huang, Zhuang Liu, Laurens van der Maaten, Kilian Q. Weinberger*  
[[paper](https://arxiv.org/abs/1608.06993)]    
CVPR 2017



In [1]:
# for ImageNet 
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np

In [None]:
class ConvLayer(nn.Module):
    """Pre-activation composite function H(x) = Conv(ReLU(BN(x))).

    Args:
        in_dim: channels of the incoming tensor (BatchNorm2d features and
            Conv2d ``in_channels``).
        out_dim: channels produced by the convolution.
        convKernel_size: square kernel size of the convolution.
        convStride: convolution stride.

    The convolution uses ``padding = convKernel_size // 2``. For an odd kernel
    this preserves height/width only when ``convStride == 1``; with the default
    stride of 2 the spatial size becomes ceil(H / 2) x ceil(W / 2).
    """

    def __init__(self, in_dim, out_dim,
                 convKernel_size=3, convStride=2) -> None:
        super().__init__()

        same_pad = convKernel_size // 2

        # Registration order (norm, act, conv) is kept so state_dict keys and
        # parameter initialisation order stay the same.
        self.norm = nn.BatchNorm2d(in_dim)
        self.act = nn.ReLU()
        self.conv = nn.Conv2d(
            in_dim,
            out_dim,
            kernel_size=convKernel_size,
            stride=convStride,
            padding=same_pad,
        )

    def forward(self, inputs):
        """Return ``conv(relu(bn(inputs)))``."""
        return self.conv(self.act(self.norm(inputs)))

In [2]:
class InitialConv(nn.Module):
    """Network stem: one composite ``ConvLayer`` followed by max pooling.

    With the defaults and a 224x224 input, the 7x7/stride-2 convolution
    (padding 3) gives 112x112 and the 3x3/stride-2 max pool (padding 1)
    gives 56x56, the sizes listed in the paper's ImageNet architecture table.

    Args:
        in_dim: channels of the input image.
        out_dim: channels produced by the stem convolution.
        convKernel_size: kernel size forwarded to ``ConvLayer``.
        convStride: stride forwarded to ``ConvLayer``.
        poolingKernel_size: max-pool window size.
        poolingStride: max-pool stride.
        poolingPadding: max-pool padding. ``None`` (default) means
            ``poolingKernel_size // 2``. Pass ``0`` to reproduce the old
            unpadded pooling.
    """

    def __init__(self, in_dim, out_dim,
                 convKernel_size=7, convStride=2, poolingKernel_size=3, poolingStride=2,
                 poolingPadding=None) -> None:
        super(InitialConv, self).__init__()

        if poolingPadding is None:
            # Same-style padding; kernel // 2 also satisfies MaxPool2d's
            # requirement that padding <= kernel_size / 2.
            poolingPadding = poolingKernel_size // 2

        self.composite = ConvLayer(in_dim, out_dim, convKernel_size, convStride)
        self.pool = nn.MaxPool2d(poolingKernel_size, stride=poolingStride, padding=poolingPadding)

    def forward(self, inputs):
        """Apply the stem convolution, then max pooling, and return the tensor."""
        h = self.composite(inputs)
        # Apply the pooling module to h (previously the module itself was returned).
        h = self.pool(h)

        return h

In [None]:
class DenseBlock(nn.Module):
    def __init__(self, in_dim, out_dim, num_conv=6,
                    convKernel_size=3, convStride=2, poolingKernel_size=3, poolingStride=2) -> None:
        """Create ``2 * num_conv`` ``ConvLayer``s: a 1x1 layer then a 3x3 layer per index.

        The layers are stored flat in ``self.consecutive`` in the order
        [1x1 #0, 3x3 #0, 1x1 #1, 3x3 #1, ...].

        NOTE(review): ``convKernel_size``, ``convStride``, ``poolingKernel_size``
        and ``poolingStride`` are accepted but neither used nor stored; the
        kernel sizes and strides below are hard-coded.
        """
        super(DenseBlock, self).__init__()

        self.num_conv = num_conv
        self.consecutive = nn.ModuleList([])

        for idx in range(self.num_conv):
            # 1x1 convolution
            # NOTE(review): ``in_dim*idx`` is 0 when idx == 0, so the first 1x1
            # layer is built with 0 input channels and cannot consume the
            # ``in_dim``-channel block input that forward() feeds into
            # ``consecutive[0]``. Presumably ``in_dim + idx * out_dim`` (the
            # channel count after idx concatenations of out_dim channels) was
            # intended; confirm.
            self.consecutive.append(ConvLayer(in_dim*idx, out_dim, convKernel_size=1, convStride=1))
            # 3x3 convolution
            # NOTE(review): stride 2 shrinks the spatial size to about half, but
            # forward() concatenates this layer's output with the
            # full-resolution input along dim=1, which needs matching spatial
            # sizes; stride 1 is presumably intended.
            self.consecutive.append(ConvLayer(out_dim, out_dim, convKernel_size=3, convStride=2))

    def forward(self, inputs):

        prev = inputs
        inter = inputs

        ## 1x1
        inter = self.consecutive[0](inter)
        ## 3x3
        inter = self.consecutive[1](inter)
        prev = torch.concat((prev, inter), dim=1)

        for idx in range(2, self.num_conv):
            