In [1]:
import math
import os
import tempfile

import torch
from torch import nn, optim
from torch.nn import functional as F
from torchinfo import summary

In [2]:
class DenseLayer(nn.Module):
    """One DenseNet-BC bottleneck layer: BN -> ReLU -> 1x1 conv -> BN -> ReLU -> 3x3 conv.

    The 1x1 conv maps the input to ``bn_size * growth_rate`` channels. The 3x3
    conv (stride 1, padding 1) then produces ``growth_rate`` new feature maps
    with the same spatial size as the input.

    Args:
        n_in: total number of input channels (after concatenation).
        growth_rate: number of channels this layer outputs.
        bn_size: bottleneck width multiplier for the 1x1 conv.
    """
    def __init__(self, n_in, growth_rate, bn_size):
        super(DenseLayer, self).__init__()
        bottleneck_channels = bn_size * growth_rate
        self.layers = nn.Sequential(
            nn.BatchNorm2d(n_in),
            nn.ReLU(),
            nn.Conv2d(n_in, bottleneck_channels, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(bottleneck_channels),
            nn.ReLU(),
            nn.Conv2d(bottleneck_channels, growth_rate, kernel_size=3, stride=1,
                      padding=1, bias=False),
        )

    def forward(self, x):
        """Compute the layer's new feature maps.

        Args:
            x: either a single tensor, or a list/tuple of tensors (the running
                feature list of a DenseBlock). A sequence is concatenated along
                dim 1 (channels) first.

        Returns:
            Only the ``growth_rate`` new feature maps, not the concatenation.
        """
        # Accept a bare tensor too, so the layer can be used outside a DenseBlock.
        if isinstance(x, torch.Tensor):
            concated_features = x
        else:
            concated_features = torch.cat(x, 1)
        return self.layers(concated_features)

In [3]:
class DenseBlock(nn.Module):
    """A stack of DenseLayers with dense (concatenative) connectivity.

    Layer ``i`` receives the concatenation of the block input and the outputs
    of layers ``0..i-1``, i.e. ``n_in + i * growth_rate`` channels. The block
    returns the concatenation (along dim 1) of its input and every layer's
    output, i.e. ``n_in + n_layers * growth_rate`` channels.

    Args:
        n_layers: number of DenseLayers in the block.
        n_in: channels of the block input.
        bn_size: bottleneck multiplier passed to each DenseLayer.
        growth_rate: channels produced by each DenseLayer.
        drop_rate: dropout probability applied to each layer's new features
            while training; 0 disables dropout.
    """
    def __init__(self, n_layers, n_in, bn_size, growth_rate, drop_rate):
        super(DenseBlock, self).__init__()
        # Previously accepted but never used; stored so forward() can apply it.
        self.drop_rate = float(drop_rate)
        self.layers = nn.ModuleList()
        for i in range(n_layers):
            layer = DenseLayer(n_in + i * growth_rate, growth_rate, bn_size)
            self.layers.append(layer)

    def forward(self, x):
        features = [x]
        for layer in self.layers:
            new_features = layer(features)
            if self.drop_rate > 0:
                # F.dropout is a no-op when training=False, so eval() disables it.
                new_features = F.dropout(new_features, p=self.drop_rate,
                                         training=self.training)
            features.append(new_features)
        return torch.cat(features, 1)

In [4]:
class TransitionBlock(nn.Module):
    """Transition between two dense blocks: BN -> ReLU -> 1x1 conv -> 2x2 average pool.

    Maps ``n_in`` channels to ``n_out`` channels. The 2x2 pooling window halves
    the height and width.
    """
    def __init__(self, n_in, n_out):
        super().__init__()
        # conv is registered before bn; keep this order so parameter and
        # state_dict ordering stay stable.
        self.conv = nn.Conv2d(n_in, n_out, kernel_size=1, bias=False)
        self.bn = nn.BatchNorm2d(n_in)

    def forward(self, x):
        activated = F.relu(self.bn(x))
        projected = self.conv(activated)
        return F.avg_pool2d(projected, kernel_size=2)

In [5]:
class DenseNet(nn.Module):
    """DenseNet-BC image classifier.

    Layout:
        stem:     7x7 conv (stride 2) -> BN -> ReLU -> 3x3 max-pool (stride 2)
        features: DenseBlock, then TransitionBlock (halves the channels), ...,
                  with one fewer transition than blocks, followed by BN -> ReLU
        avgpool:  global average pooling to 1x1
        head:     linear classifier

    Args:
        n_classes: number of output logits.
        growth_rate: channels each DenseLayer adds.
        blocks: number of DenseLayers in each DenseBlock.
        num_init_features: channels produced by the stem convolution.
        bn_size: bottleneck multiplier for each DenseLayer's 1x1 conv.
        drop_rate: dropout probability forwarded to every DenseBlock.
    """
    def __init__(self, n_classes=1000, growth_rate=32, blocks=(6, 12, 24, 16),
                 num_init_features=64, bn_size=4, drop_rate=0):
        super(DenseNet, self).__init__()
        self.stem = nn.Sequential(
            nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(num_init_features),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )
        features = []

        n_feats = num_init_features
        for i, n_layers in enumerate(blocks):
            block = DenseBlock(n_layers, n_feats, bn_size, growth_rate, drop_rate=drop_rate)
            n_feats += n_layers * growth_rate
            features.append(block)
            if i != len(blocks) - 1:
                # Compression factor 0.5: each transition halves the channel count.
                features.append(TransitionBlock(n_feats, n_feats // 2))
                n_feats = n_feats // 2

        # The layers are pre-activation (BN -> ReLU -> conv), so the last block's
        # output has never been normalized or activated. Finish with BN + ReLU
        # before pooling. The ReLU has no parameters, so state_dict keys are unchanged.
        features.append(nn.BatchNorm2d(n_feats))
        features.append(nn.ReLU())
        self.features = nn.Sequential(*features)
        # Global *average* pooling, as the attribute name and the DenseNet design
        # intend. Parameter-free, so swapping it does not affect checkpoints.
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.head = nn.Sequential(
            nn.Linear(n_feats, n_classes)
        )

    def forward(self, x):
        """Map an image batch to class logits.

        Assumes x is (N, 3, H, W), since the stem conv takes 3 input channels.
        Returns logits of shape (N, n_classes).
        """
        x = self.stem(x)
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.head(x)
        return x

In [6]:
def densenet121(n_classes=1000):
    """DenseNet-121: growth rate 32, 64 stem channels, blocks of (6, 12, 24, 16) layers."""
    layers_per_block = (6, 12, 24, 16)
    return DenseNet(n_classes=n_classes, growth_rate=32,
                    blocks=layers_per_block, num_init_features=64)
def densenet169(n_classes=1000):
    """DenseNet-169: growth rate 32, 64 stem channels, blocks of (6, 12, 32, 32) layers."""
    layers_per_block = (6, 12, 32, 32)
    return DenseNet(n_classes=n_classes, growth_rate=32,
                    blocks=layers_per_block, num_init_features=64)
def densenet201(n_classes=1000):
    """DenseNet-201: growth rate 32, 64 stem channels, blocks of (6, 12, 48, 32) layers."""
    layers_per_block = (6, 12, 48, 32)
    return DenseNet(n_classes=n_classes, growth_rate=32,
                    blocks=layers_per_block, num_init_features=64)
def densenet264(n_classes=1000):
    """DenseNet-264: growth rate 32, 64 stem channels, blocks of (6, 12, 64, 48) layers."""
    layers_per_block = (6, 12, 64, 48)
    return DenseNet(n_classes=n_classes, growth_rate=32,
                    blocks=layers_per_block, num_init_features=64)

In [7]:
# Build one randomly initialised instance of each depth (constructed left to right).
d121, d169, d201, d264 = densenet121(), densenet169(), densenet201(), densenet264()

In [8]:
# Ordered shallowest to deepest so the per-model outputs below line up: 121, 169, 201, 264.
models = list((d121, d169, d201, d264))

In [9]:
%%time
inp = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    d121(inp).shape, d201(inp).shape

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


CPU times: user 398 ms, sys: 92.4 ms, total: 491 ms
Wall time: 217 ms


In [10]:
def fmat(n):
    """Format a count in millions with two decimals, e.g. 7_980_000 -> '7.98M'."""
    millions = n / 1e6
    return f"{millions:.2f}M"

In [11]:
def params(model, f=True):
    """Count the model's trainable parameters (those with requires_grad=True).

    Returns the count formatted by fmat (e.g. '7.98M') when f is True,
    otherwise the raw integer.
    """
    total = 0
    for p in model.parameters():
        if p.requires_grad:
            total += p.numel()
    if f:
        return fmat(total)
    return total

In [12]:
import os
def print_size_of_model(model):
    """Print (and return) the on-disk size of ``model.state_dict()`` in MB (1e6 bytes).

    The state dict is saved with torch.save into a private temporary directory.
    That directory is removed even if saving fails, so no ``temp.p`` in the
    current working directory is created or clobbered. The file keeps the name
    ``temp.p``, so the serialized bytes match those of a plain save to ``temp.p``.

    Returns:
        The size in MB as a float.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, "temp.p")
        torch.save(model.state_dict(), path)
        size_mb = os.path.getsize(path) / 1e6
    print('Size (MB):', size_mb)
    return size_mb

In [13]:
# Trainable parameter count of each model, in millions.
for model in models:
    print(params(model))

7.98M
14.15M
20.01M
33.34M


In [14]:
# Serialized state_dict size of each model, in MB.
for model in models:
    print_size_of_model(model)

Size (MB): 32.499987
Size (MB): 57.581538
Size (MB): 81.389026
Size (MB): 135.496162


In [15]:
# Layer-by-layer summary of DenseNet-121 for one 224x224 RGB image.
summary(d121, input_size=(1, 3, 224, 224))

Layer (type:depth-idx)                             Output Shape              Param #
DenseNet                                           --                        --
├─Sequential: 1                                    --                        --
│    └─DenseBlock: 2                               --                        --
│    │    └─ModuleList: 3-1                        --                        335,040
│    └─DenseBlock: 2                               --                        --
│    │    └─ModuleList: 3-2                        --                        919,680
│    └─DenseBlock: 2                               --                        --
│    │    └─ModuleList: 3-3                        --                        2,837,760
│    └─DenseBlock: 2                               --                        --
│    │    └─ModuleList: 3-4                        --                        2,158,080
├─Sequential: 1-1                                  [1, 64, 56, 56]           --
│    └─Conv