In [40]:
import torch as tc, sys, math
from torch import nn
from torch import optim
from torch.nn import functional as F
from torch.autograd import Variable
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision import models

In [41]:
class Bottleneck(nn.Module):
    def __init__(self, num_channels, growth_rate):
        super(Bottleneck, self).__init__()
        inter_channels     = 4*growth_rate
        self.batch_norm1   = nn.BatchNorm2d(num_channels)
        self.conv1         = nn.Conv2d(num_channels, inter_channels, kernel_size = 1, bias = False)
        self.batch_norm2   = nn.BatchNorm2d(inter_channels)
        self.conv2         = nn.Conv2d(inter_channels, growth_rate, kernel_size = 3, padding = 1, bias = False)
    
    def forward(self, x):
        output = self.conv1(F.relu(self.batch_norm1(x)))
        output = self.conv2(F.relu(self.batch_norm2(output)))
        output = tc.cat((x, output), 1)
        return output

In [42]:
class SingleLayer(nn.Module):
    def __init__(self, num_channels, growth_rate):
        super(SingleLayer, self).__init__()
        self.batch_norm = nn.BatchNorm2d(num_channels)
        self.conv       = nn.Conv2d(num_channels, growth_rate, kernel_size = 3, padding = 1, bias = False)
        
    def forward(self, x):
        return tc.cat((x, self.conv(F.relu(self.batch_norm(x)))), 1)
        

In [43]:
class Transition(nn.Module):
    def __init__(self, num_channels, num_out_channels):
        super(Transition, self).__init__()
        self.batch_norm = nn.BatchNorm2d(num_channels)
        self.conv       = nn.Conv2d(num_channels, num_out_channels, kernel_size = 1, bias = False)
        
    def forward(self, x):
        return F.avg_pool2d(self.conv(F.relu(self.batch_norm(x))), 2)
        

In [44]:
class DenseNet(nn.Module):
    def __init__(self, bottleneck, num_blocks, growth_rate = 12, reduction = 0.5, num_classes = 10):
        super(DenseNet, self).__init__()        
        
        self.growth_rate = growth_rate
        num_channels = 2 * growth_rate
        
        self.conv    = nn.Conv2d(3, num_channels, kernel_size = 3, padding = 1, bias = False)
        self.dense1  = self._make_dense(num_channels, num_blocks[0], bottleneck)
        num_channels += num_blocks[0] * growth_rate
        num_out_channels = int(math.floor(num_channels * reduction))
        self.trans1  = Transition(num_channels, num_out_channels)       
        num_channels = num_out_channels
        
        self.dense2  = self._make_dense(num_channels, num_blocks[1], bottleneck)
        num_channels += num_blocks[1] * growth_rate
        num_out_channels = int(math.floor(num_channels * reduction))
        self.trans2  = Transition(num_channels, num_out_channels)       
        num_channels = num_out_channels
        
        self.dense3  = self._make_dense(num_channels, num_blocks[2], bottleneck)
        num_channels += num_blocks[2] * growth_rate
        num_out_channels = int(math.floor(num_channels * reduction))
        self.trans3  = Transition(num_channels, num_out_channels)       
        num_channels = num_out_channels
        
        self.dense4  = self._make_dense(num_channels, num_blocks[3], bottleneck)
        num_channels += num_blocks[3] * growth_rate        
        self.batch_norm = nn.BatchNorm2d(num_channels)
        self.fc = nn.Linear(num_channels, num_classes)                
    
    def _make_dense(self, num_channels, num_blocks, bottleneck):
        layers = []
        for block in range(int(num_blocks)):
            if bottleneck:
                layers.append(Bottleneck(num_channels, self.growth_rate))
            else:
                layers.append(SingleLayer(num_channels, self.growth_rate))
            num_channels += self.growth_rate
        return nn.Sequential(*layers)
    
    def forward(self, x):
        x = self.conv(x)
        x = self.trans1(self.dense1(x))
        x = self.trans2(self.dense2(x))
        x = self.trans3(self.dense3(x))
        x = self.dense4(x)
        x = F.avg_pool2d(F.relu(self.batch_norm(x)), 4)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x



In [45]:
def denseNet():
    return DenseNet(Bottleneck, [6,12,24,16], growth_rate=12)

def check():
    model = denseNet()
    x = tc.randn(1, 3, 32, 32)
    y = model(x)
    return y
check()

tensor([[ 0.3411, -0.0695,  0.1289, -0.1565,  0.2002, -0.0652, -0.1514, -0.0220,
          0.3215,  0.0416]], grad_fn=<AddmmBackward>)