In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

In [34]:
class TransitionLayer(nn.Module):
    def __init__(self,num_features,reduction):
        super().__init__()
        self.conv = nn.Sequential(
            nn.BatchNorm2d(num_features),
            nn.ReLU(),
            nn.Conv2d(num_features,int(num_features * reduction),kernel_size=1,stride=1,bias=False),
            nn.AvgPool2d(kernel_size=2,stride=2),
        )
    def forward(self,x):
        return self.conv(x)

In [35]:
x = torch.randn(1,32,56,56)
TransitionLayer(num_features=32,reduction=0.5)(x).shape

torch.Size([1, 16, 28, 28])

In [24]:
class DenseLayer(nn.Module):
    def __init__(self,num_features,growth_rate,bottleneck_size,dropout=0.2):
        super().__init__()
        self.bottleneck = nn.Sequential(
            nn.BatchNorm2d(num_features),
            nn.ReLU(),
            nn.Conv2d(num_features,bottleneck_size * growth_rate,kernel_size=1,stride=1,bias=False),
        ) 
        self.conv = nn.Sequential(
            nn.BatchNorm2d(bottleneck_size * growth_rate),
            nn.ReLU(),
            nn.Conv2d(bottleneck_size * growth_rate, growth_rate,kernel_size=3,stride=1,padding=1,bias=False),
            nn.Dropout(dropout),
        )
        
    def forward(self,input):
        prev_features = input
        prev_features = torch.cat(prev_features,dim=1)
        bottleneck_out = self.bottleneck(prev_features)
        out_features = self.conv(bottleneck_out)
        
        return out_features

In [25]:
class DenseBlock(nn.Module):
    def __init__(self,num_layers,num_features,bottleneck_size,growth_rate):
        super().__init__()
        self.layers = nn.ModuleList([DenseLayer(num_features + i * growth_rate,growth_rate,bottleneck_size) for i in range(num_layers)])
    
    def forward(self,init_features):
        features = [init_features]
        for layer in self.layers:
            new_features = layer(features)
            features.append(new_features)
        return torch.cat(features,dim=1)

In [58]:
class DenseNet(nn.Module):
    def __init__(self,num_init_features,bottleneck_size,growth_rate,reduction=0.5,blocks=[6,12,24,16],num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            nn.BatchNorm2d(1),
            nn.ReLU(),
            nn.Conv2d(1,num_init_features,kernel_size=3,padding=1),
            nn.MaxPool2d(kernel_size=2,stride=2),
        )
        num_features = num_init_features
        
        self.block1 = DenseBlock(blocks[0],num_features,bottleneck_size,growth_rate)
        num_features = num_features + growth_rate * blocks[0]
        self.trans1 = TransitionLayer(num_features,reduction)
        num_features = int(num_features * reduction)
        
        self.block2 = DenseBlock(blocks[1],num_features,bottleneck_size,growth_rate)
        num_features = num_features + growth_rate * blocks[1]
        self.trans2 = TransitionLayer(num_features,reduction)
        num_features = int(num_features * reduction)
        
        self.block3 = DenseBlock(blocks[2],num_features,bottleneck_size,growth_rate)
        num_features = num_features + growth_rate * blocks[2]
        self.trans3 = TransitionLayer(num_features,reduction)
        num_features = int(num_features * reduction)
        
        self.block4 = DenseBlock(blocks[3],num_features,bottleneck_size,growth_rate)
        num_features = num_features + growth_rate * blocks[3]
        
        self.pool = nn.AvgPool2d(4)
        self.head = nn.Linear(num_features,num_classes)

        
    def forward(self,x):
        x = self.features(x)
        
        x = self.block1(x)
        x = self.trans1(x)
        
        x = self.block2(x)
        x = self.trans2(x)
        
        x = self.block3(x)
        x = self.trans3(x)
        
        x = self.block4(x)
        
        x = self.pool(x)
        
        x = x.view(-1)
        
        x = self.head(x)
        return x

In [59]:
model = DenseNet(num_init_features=64,bottleneck_size=4,growth_rate=32)

In [60]:
sum([p.numel() for p in model.parameters()])

6953164

In [61]:
x = torch.randn(1,1,64,64)

In [62]:
model(x)

tensor([-0.0873,  0.0203, -0.0730,  0.0750, -0.0294,  0.0793,  0.3074, -0.1312,
         0.0450, -0.0148], grad_fn=<ViewBackward0>)

In [46]:
net = nn.Sequential(
    nn.Conv2d(64,128,kernel_size=1,stride=1,bias=False),
    nn.Conv2d(128,64,kernel_size=3,stride=1,padding=1,bias=False),
)

In [26]:
x = torch.randn(1,64,56,56)
net(x).shape

torch.Size([1, 64, 56, 56])