In [2]:
import torch
from torch import nn
import torch.nn.functional as F

In [3]:
class Bottleneck(nn.Module):
    """Bottleneck layer in the DenseNet style: BN-ReLU-Conv(1x1), BN-ReLU-Conv(3x3).

    The 1x1 convolution maps the input to ``4 * GrowthRate`` channels and the
    3x3 convolution (padding 1, so height and width are unchanged) brings that
    down to ``GrowthRate`` channels. Those new channels are concatenated onto
    the untouched input along dim 1.

    Args:
        nChannels: number of channels of the incoming tensor.
        GrowthRate: number of channels this layer appends to its input.
    """

    def __init__(self, nChannels, GrowthRate):
        super().__init__()
        hidden_width = GrowthRate * 4

        # First stage: normalise the input, then widen to the hidden width.
        self.bn1 = nn.BatchNorm2d(nChannels)
        self.conv1 = nn.Conv2d(
            nChannels,
            hidden_width,
            kernel_size=1,
            bias=False,
        )

        # Second stage: normalise again, then produce the new feature maps.
        self.bn2 = nn.BatchNorm2d(hidden_width)
        self.conv2 = nn.Conv2d(
            hidden_width,
            GrowthRate,
            kernel_size=3,
            padding=1,
            bias=False,
        )

    def forward(self, x):
        """Return ``x`` with ``GrowthRate`` freshly computed channels appended on dim 1."""
        hidden = self.bn1(x)
        hidden = F.relu(hidden)
        hidden = self.conv1(hidden)

        new_features = self.bn2(hidden)
        new_features = F.relu(new_features)
        new_features = self.conv2(new_features)

        return torch.cat([x, new_features], dim=1)

In [4]:
class Denseblock(nn.Module):
    """Sequential stack of ``nDenseBlocks`` Bottleneck layers.

    Each Bottleneck concatenates ``GrowthRate`` channels onto its input, so the
    layer at position ``i`` is constructed for ``nChannels + i * GrowthRate``
    input channels.

    Args:
        nChannels: number of channels entering the first layer.
        GrowthRate: number of channels each layer adds.
        nDenseBlocks: how many Bottleneck layers to stack (passed through ``int``).
    """

    def __init__(self, nChannels, GrowthRate, nDenseBlocks):
        super().__init__()
        depth = int(nDenseBlocks)
        stages = [
            Bottleneck(nChannels + idx * GrowthRate, GrowthRate)
            for idx in range(depth)
        ]
        self.denseblock = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through every stacked layer in order and return the result."""
        features = self.denseblock(x)
        return features

In [5]:
# 64 input channels, growth rate 32, six stacked bottleneck layers.
denseblock = Denseblock(nChannels=64, GrowthRate=32, nDenseBlocks=6)

In [6]:
# Display the module tree to inspect the channel widths of each layer.
denseblock

Denseblock(
  (denseblock): Sequential(
    (0): Bottleneck(
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    )
    (1): Bottleneck(
      (bn1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv1): Conv2d(96, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    )
    (2): Bottleneck(
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=F

In [7]:
# Shape check: push one random 64-channel input through the block.
# Each of the six bottlenecks concatenates 32 channels onto its input,
# so dim 1 of the result should be 64 + 6 * 32 = 256.
data = torch.randn(1, 64, 256, 256)
out = denseblock(data)
out.shape

torch.Size([1, 256, 256, 256])