<a href="https://colab.research.google.com/github/TimofeyKulakov/NeuralNets/blob/master/DenseNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Model

In [None]:
def test_module(module, test_size = (5, 3, 224, 224)):
  test = torch.randn(test_size)
  return module(test).shape

In [None]:
class Bottleneck(nn.Module):
  """DenseNet-BC bottleneck: BN-ReLU-Conv1x1, then BN-ReLU-Conv3x3, then concat.

  The 1x1 convolution widens to ``4 * growth_rate`` channels. The 3x3
  convolution (padding 1) produces ``growth_rate`` new channels. These are
  appended to the input along dim 1, so the output has
  ``input_channels + growth_rate`` channels and the same spatial size.
  Neither convolution has a bias.
  """

  def __init__(self, input_channels, growth_rate):
    super().__init__()

    # Width of the intermediate 1x1 projection ("4k" in DenseNet-BC).
    bottleneck_width = growth_rate * 4

    self.bn1 = nn.BatchNorm2d(input_channels)
    self.conv1 = nn.Conv2d(input_channels, bottleneck_width, kernel_size=1, bias=False)
    self.bn2 = nn.BatchNorm2d(bottleneck_width)
    self.conv2 = nn.Conv2d(bottleneck_width, growth_rate, kernel_size=3, padding=1, bias=False)

  def forward(self, x):
    hidden = self.conv1(F.relu(self.bn1(x)))
    new_features = self.conv2(F.relu(self.bn2(hidden)))
    # Dense connectivity: keep the input and append the new feature maps.
    return torch.cat([x, new_features], dim=1)

In [None]:
class Transition(nn.Module):
  """DenseNet transition layer: BN -> ReLU -> 1x1 conv -> 2x2 average pool.

  Maps ``in_channels`` to ``out_channels`` and halves both spatial
  dimensions (floored for odd sizes, per AvgPool2d with stride 2).
  """

  def __init__(self, in_channels, out_channels):
    super(Transition, self).__init__()

    # The incoming tensor is a raw concatenation of conv outputs from a dense
    # block, so normalise and activate it before the 1x1 projection. This is
    # the same pre-activation ordering used by every other layer here.
    self.bn1 = nn.BatchNorm2d(in_channels)
    self.conv1 = nn.Conv2d(in_channels, out_channels, 1)
    self.avgpool = nn.AvgPool2d((2, 2), stride = 2)

  def forward(self, x):
    out = self.conv1(F.relu(self.bn1(x)))
    return self.avgpool(out)

In [None]:
class DenseLayer(nn.Module):
  """Pre-activation 3x3 convolution: BN -> ReLU -> Conv3x3 (padding 1, with bias).

  Maps ``in_channels`` to ``out_channels`` at unchanged spatial size.
  Only the new feature maps are returned; concatenation with earlier
  features is done by the caller.
  """

  def __init__(self, in_channels, out_channels):
    super().__init__()
    self.bn1 = nn.BatchNorm2d(in_channels)
    self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

  def forward(self, x):
    activated = F.relu(self.bn1(x))
    return self.conv1(activated)

In [None]:
class DenseBlock(nn.Module):
  """Stack of ``num_layers`` DenseLayers with concatenative (dense) connectivity.

  Layer ``i`` receives the block input concatenated with the outputs of all
  previous layers (``in_channels + i * growth_rate`` channels) and emits
  ``growth_rate`` new channels. The block returns the input concatenated with
  every layer's output: ``in_channels + num_layers * growth_rate`` channels.
  """

  def __init__(self, in_channels, growth_rate, num_layers = 4):
    super().__init__()

    layers = []
    for depth in range(num_layers):
      layers.append(DenseLayer(in_channels + depth * growth_rate, growth_rate))
    self.mod = nn.ModuleList(layers)

  def forward(self, x):
    features = [x]
    for layer in self.mod:
      # For the first layer this concatenates a single tensor, i.e. it is
      # just the block input, so no special case is needed.
      features.append(layer(torch.cat(features, dim = 1)))
    return torch.cat(features, dim = 1)


In [None]:
# Smoke-test each building block on random input and print its output shape.
# Expected channels: Transition 16->32 (spatial halved), Bottleneck 3+4=7,
# DenseLayer 8, DenseBlock 3+6*4=27.
print(test_module(Transition(16, 32), test_size=(5, 16, 100, 100)))
print(test_module(Bottleneck(input_channels=3, growth_rate=4)))
print(test_module(DenseLayer(in_channels=3, out_channels=8)))
print(test_module(DenseBlock(in_channels=3, growth_rate=4, num_layers=6)))

torch.Size([5, 32, 50, 50])
torch.Size([5, 7, 224, 224])
torch.Size([5, 8, 224, 224])
torch.Size([5, 27, 224, 224])


In [None]:
class DenseNet(nn.Module):
  """DenseNet image classifier.

  Layout: 7x7 stride-2 conv stem (``2 * growth_rate`` channels) -> 3x3
  stride-2 max-pool -> DenseBlock -> Transition -> ... -> DenseBlock ->
  global average pool -> Linear -> softmax. Every Transition compresses its
  input to ``growth_rate`` channels. No Transition follows the last block.

  Args:
    in_channels: number of channels in the input tensor.
    growth_rate: channels added by each DenseLayer; also sets the stem width
      (2 * growth_rate) and the output width of every Transition.
    num_classes: number of output classes.
    dense_layers_num: number of DenseLayers in each DenseBlock, one entry per
      block. Must be a non-empty list or tuple.

  Raises:
    AssertionError: if ``dense_layers_num`` is not a list or tuple.
    ValueError: if ``dense_layers_num`` is empty.
  """

  def __init__(self, in_channels, growth_rate, num_classes, dense_layers_num = (6, 12, 24, 16)):
    super(DenseNet, self).__init__()

    assert isinstance(dense_layers_num, (list, tuple)), 'dense_layers_num must be list or tuple containing numbers of dense layers in each dense block,  e.g. [2, 4, 6]'
    if not dense_layers_num:
      raise ValueError('dense_layers_num must contain at least one dense block')

    self.conv1 = nn.Conv2d(in_channels, 2 * growth_rate, (7, 7), stride = 2, padding = 3)
    self.maxpool1 = nn.MaxPool2d(3, stride = 2, padding = 1)

    # Track the running channel count so the classifier width is right for
    # any number of blocks. The first block starts from the 2*growth_rate
    # stem; later blocks start from a growth_rate-wide Transition. The old
    # hard-coded formula was wrong for a single block.
    channels = 2 * growth_rate
    last_block = len(dense_layers_num) - 1
    modules = []
    for block_idx, num_layers in enumerate(dense_layers_num):
      modules.append(DenseBlock(channels, growth_rate, num_layers))
      channels += num_layers * growth_rate
      if block_idx < last_block:
        modules.append(Transition(channels, growth_rate))
        channels = growth_rate
    # Same ordering as before: [block, transition, ..., block].
    self.mods = nn.ModuleList(modules)

    # `channels` is now the output width of the last DenseBlock.
    self.fc = nn.Linear(channels, num_classes)

  def forward(self, x):
    out = self.maxpool1(self.conv1(x))

    for m in self.mods:
      out = m(out)
    # Global average pooling over dims 2 and 3 (the spatial dims of the
    # conv feature maps).
    out = out.mean([2, 3])
    out = self.fc(out)
    # NOTE(review): this returns probabilities, not logits. nn.CrossEntropyLoss
    # expects raw logits, so training with it on this output would apply
    # softmax twice. Confirm against the training loop.
    out = torch.softmax(out, dim = 1)
    return out



In [None]:
# Smoke-test the full network on test_module's default random batch of shape
# (5, 3, 224, 224); expect one row of 10 class scores per sample.
test_module(DenseNet(in_channels=3, growth_rate=32, num_classes=10))

torch.Size([5, 10])