In [1]:
import torch
import torch.nn as nn
from collections import OrderedDict


x = torch.randn(1, 3, 224, 224)


In [11]:
class DenseLayer(nn.Module):
    def __init__(self, in_feature, growth_rate):
        super().__init__()
        inter_channel = 4*growth_rate
        self.bn1 = nn.BatchNorm2d(in_feature)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_channels=in_feature, out_channels=inter_channel, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn2 = nn.BatchNorm2d(inter_channel)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(in_channels=inter_channel, out_channels=growth_rate, kernel_size=3, stride=1, padding=1, bias=False)
    
    def forward(self, x):
        out = self.conv1(self.relu1(self.bn1(x)))
        out = self.conv2(self.relu2(self.bn2(out)))
        return out # torch.cat([x, out], 1)

In [16]:
class DenseBlock(nn.Module):
    def __init__(self, in_feature, num_layer, growth_rate):
        super().__init__()
        for idx in range(num_layer):
            input_feature = in_feature + idx * growth_rate
            layer = DenseLayer(input_feature, growth_rate)
            self.add_module("denselayer%d"%(idx+1), layer)
    def forward(self, init_feature):
        features = [init_feature]
        for name, layer in self.named_children():
            new_features = layer(features)
            # features.append(new_features)
            # print(new_features.shape)
        return new_features # torch.cat(features, 1)

In [17]:
class TransitionBlock(nn.Module):
    def __init__(self, in_feature, compression):
        super(TransitionBlock, self).__init__()
        out_feature = int(in_feature*compression)
        self.bn = nn.BatchNorm2d(in_feature)
        self.relu = nn.ReLU(inplace=True)
        self.conv = nn.Conv2d(in_channels=in_feature, out_channels=out_feature, kernel_size=1, stride=1, padding=0, bias=False)
        self.pool = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
        
    def forward(self, x):
        out = self.pool(self.conv(self.relu(self.bn(x))))
        return out

In [18]:
class DenseNet(nn.Module):
    def __init__(self, in_feature=3, num_classes=1000, num_dense_block=[6, 12, 24, 16], growth_rate=32, compression=0.5):
        super(DenseNet, self).__init__()
        initial_feature = 2*growth_rate
        self.encoder = nn.Sequential(
            OrderedDict([
                ("conv0", nn.Conv2d(in_channels=in_feature, out_channels=initial_feature, kernel_size=7, stride=2, padding=3, bias=False)), 
                ("bn0", nn.BatchNorm2d(num_features=initial_feature)),
                ("relu0", nn.ReLU(inplace=True)),
                ("pool0", nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
            ])
        )
        
        for idx, num_layer in enumerate(num_dense_block):
            self.encoder.add_module("denseblock%d"%(idx+1), DenseBlock(initial_feature, num_layer, growth_rate))
            initial_feature = initial_feature + num_layer * growth_rate
            if idx != len(num_dense_block)-1:
                self.encoder.add_module("transition%d"%(idx+1), TransitionBlock(initial_feature, compression))
                initial_feature = initial_feature//2
        
        self.bn1 = nn.BatchNorm2d(initial_feature)
        self.avgpool1 = nn.AvgPool2d(kernel_size=7, stride=3)
        self.classifier = nn.Linear(in_features=initial_feature, out_features=num_classes, bias=True)
        
    
    def forward(self, x):
        out = self.encoder(x)
        out = self.avgpool1(self.bn1(out))
        out = torch.flatten(out, 1)
        out = self.classifier(out)
        return out

In [19]:
dense = DenseBlock(64, 6, 32)
out = dense(torch.randn(2, 64, 56, 56))
out.shape

AttributeError: 'list' object has no attribute 'dim'

In [16]:
trans = TransitionBlock(1120, 0.5)
out = trans(out)
out.shape

torch.Size([2, 560, 28, 28])

In [8]:
x = torch.randn(2, 3, 224, 224)
model = DenseNet()
out = model(x)

RuntimeError: running_mean should contain 1120 elements not 256

In [9]:
model

DenseNet(
  (encoder): Sequential(
    (conv0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu0): ReLU(inplace=True)
    (pool0): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (denseblock1): DenseBlock(
      (denselayer1): DenseLayer(
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (denselayer2): DenseLayer(
        (bn1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inp