In [1]:
# Codeblock 1
import torch
import torch.nn as nn

# Growth rate: number of new channels each Bottleneck concatenates onto its input.
GROWTH      = 12
# Factor applied to the channel count by each Transition (out = int(in * COMPRESSION)).
COMPRESSION = 0.5
# Number of Bottleneck layers in each dense block, listed in network order.
REPEATS     = [6, 12, 24, 16]

In [4]:
# Codeblock 2
class Bottleneck(nn.Module):
    """Single DenseNet bottleneck layer.

    Applies BN -> ReLU -> 1x1 conv -> dropout, then BN -> ReLU -> 3x3 conv ->
    dropout, and concatenates the resulting ``GROWTH`` channels with the
    untouched input along dim 1.

    Args:
        in_channels: number of channels of the incoming tensor.

    Returns (forward):
        Tensor with ``in_channels + GROWTH`` channels; the new features come
        first, followed by the original input.
    """

    def __init__(self, in_channels):
        super().__init__()

        # Width of the intermediate representation produced by the 1x1 conv.
        width = GROWTH * 4                                          #(1)

        # ReLU and Dropout hold no parameters, so one instance of each is
        # reused at both stages.
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.2)

        # Stage 0: 1x1 convolution, in_channels -> width.
        self.bn0 = nn.BatchNorm2d(in_channels)
        self.conv0 = nn.Conv2d(in_channels, width,
                               kernel_size=1, padding=0, bias=False)

        # Stage 1: 3x3 convolution, width -> GROWTH.
        self.bn1 = nn.BatchNorm2d(width)
        self.conv1 = nn.Conv2d(width, GROWTH,                       #(2)
                               kernel_size=3, padding=1, bias=False)

    def forward(self, x):
        # Uncomment the print lines below to trace tensor shapes.
        #print(f'original\t: {x.size()}')

        squeezed = self.relu(self.bn0(x))
        squeezed = self.conv0(squeezed)
        squeezed = self.dropout(squeezed)
        #print(f'after conv0\t: {squeezed.size()}')

        new_features = self.relu(self.bn1(squeezed))
        new_features = self.conv1(new_features)
        new_features = self.dropout(new_features)
        #print(f'after conv1\t: {new_features.size()}')

        # Keep the input alongside the freshly computed features.
        merged = torch.cat([new_features, x], dim=1)                #(3)
        #print(f'after concat\t: {merged.size()}')

        return merged

In [3]:
# Codeblock 3
# Shape check: run a single Bottleneck on a random tensor with 64 channels.
bottleneck = Bottleneck(in_channels=64)

x = torch.randn(1, 64, 56, 56)
# The result carries 64 + GROWTH channels (new features concatenated with the input).
# NOTE(review): the shape trace recorded below this cell is only printed when the
# print lines inside Bottleneck.forward are uncommented.
x = bottleneck(x)

original	: torch.Size([1, 64, 56, 56])
after conv0	: torch.Size([1, 48, 56, 56])
after conv1	: torch.Size([1, 12, 56, 56])
after concat	: torch.Size([1, 76, 56, 56])


In [5]:
# Codeblock 4
class DenseBlock(nn.Module):
    """Sequence of ``repeats`` Bottleneck layers applied one after another.

    Every Bottleneck adds ``GROWTH`` channels to the tensor it receives, so
    the i-th layer is built for ``in_channels + i*GROWTH`` input channels.

    Args:
        in_channels: channel count of the tensor entering the block.
        repeats: how many Bottleneck layers to stack.

    Note:
        ``forward`` prints the tensor size after each bottleneck.
    """

    def __init__(self, in_channels, repeats):
        super().__init__()

        # Layer i receives the block input plus the channels contributed by
        # the i layers in front of it.                               #(2)
        layers = [
            Bottleneck(in_channels=in_channels + i*GROWTH)          #(3)
            for i in range(repeats)
        ]

        # ModuleList registers every layer as a submodule of this block.
        self.bottlenecks = nn.ModuleList(layers)                    #(1)

    def forward(self, x):
        # Feed the output of each bottleneck into the next one.
        for i in range(len(self.bottlenecks)):
            x = self.bottlenecks[i](x)
            print(f'after bottleneck #{i}\t: {x.size()}')

        return x

In [6]:
# Codeblock 5
# Shape check: a dense block of 6 bottlenecks on a random tensor with 64 channels.
dense_block = DenseBlock(in_channels=64, repeats=6)
x = torch.randn(1, 64, 56, 56)

# DenseBlock.forward prints the size after every bottleneck; each one adds
# GROWTH channels, so the block ends at 64 + 6*GROWTH channels.
x = dense_block(x)

after bottleneck #0	: torch.Size([1, 76, 56, 56])
after bottleneck #1	: torch.Size([1, 88, 56, 56])
after bottleneck #2	: torch.Size([1, 100, 56, 56])
after bottleneck #3	: torch.Size([1, 112, 56, 56])
after bottleneck #4	: torch.Size([1, 124, 56, 56])
after bottleneck #5	: torch.Size([1, 136, 56, 56])


In [9]:
# Codeblock 6
class Transition(nn.Module):
    """Transition layer placed between two dense blocks.

    Applies BN -> ReLU -> 1x1 conv (``in_channels`` -> ``out_channels``) ->
    dropout, followed by a 2x2 average pool with stride 2.

    Args:
        in_channels: channel count of the incoming tensor.
        out_channels: channel count produced by the 1x1 convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.bn = nn.BatchNorm2d(in_channels)
        self.relu = nn.ReLU()
        # Pointwise convolution: changes only the number of channels.
        self.conv = nn.Conv2d(in_channels, out_channels,
                              kernel_size=1, padding=0, bias=False)
        self.dropout = nn.Dropout(0.2)
        # Spatial downsampling step.
        self.pool = nn.AvgPool2d(2, stride=2)    #(1)

    def forward(self, x):
        # Uncomment the print lines below to trace tensor shapes.
        #print(f'original\t: {x.size()}')

        activated = self.relu(self.bn(x))
        reduced = self.dropout(self.conv(activated))
        out = self.pool(reduced)
        #print(f'after transition: {out.size()}')

        return out

In [8]:
# Codeblock 7
# Shape check: a Transition on a random tensor with 136 channels, with the
# output channel count scaled by COMPRESSION.
transition = Transition(in_channels=136, out_channels=int(136*COMPRESSION))

x = torch.randn(1, 136, 56, 56)
# NOTE(review): the shape trace recorded below this cell is only printed when the
# print lines inside Transition.forward are uncommented.
x = transition(x)

original	: torch.Size([1, 136, 56, 56])
after transition: torch.Size([1, 68, 28, 28])


In [10]:
# Codeblock 8a
class DenseNet(nn.Module):
    """DenseNet-style classifier assembled from DenseBlock and Transition stages.

    Layout: a 7x7 stride-2 convolution and a 3x3 stride-2 max-pool form the
    stem; one DenseBlock is then created per entry of ``REPEATS``. Every dense
    block except the last one is followed by a Transition whose output channel
    count is ``int(channels * COMPRESSION)``. Global average pooling and a
    linear layer produce the final scores.

    Submodules are registered as ``dense_block_<i>`` and ``transition_<i>``
    (i counted from 0), in network order.

    Args:
        num_classes: number of output features of the final linear layer.
            Defaults to 1000.

    Note:
        ``forward`` prints the tensor size after every stage (and DenseBlock
        prints after every bottleneck).
    """

    def __init__(self, num_classes=1000):
        super().__init__()

        self.first_conv = nn.Conv2d(in_channels=3,
                                    out_channels=64,
                                    kernel_size=7,    #(1)
                                    stride=2,         #(2)
                                    padding=3,        #(3)
                                    bias=False)
        self.first_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)  #(4)
        channel_count = 64

        # Remember the depth used at construction time so forward() does not
        # depend on the module-level REPEATS list staying unchanged.
        self.num_stages = len(REPEATS)
        last_stage = self.num_stages - 1

        for stage, repeats in enumerate(REPEATS):
            # Dense block #stage
            self.add_module(f'dense_block_{stage}',
                            DenseBlock(in_channels=channel_count,
                                       repeats=repeats))              #(5)
            # Each bottleneck in the block contributes GROWTH channels.
            channel_count = int(channel_count + repeats*GROWTH)       #(6)

            # A transition follows every dense block except the final one.
            if stage < last_stage:
                compressed = int(channel_count*COMPRESSION)
                self.add_module(f'transition_{stage}',
                                Transition(in_channels=channel_count,
                                           out_channels=compressed))
                channel_count = compressed                            #(7)

        self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1,1))       #(8)
        self.fc = nn.Linear(in_features=channel_count,
                            out_features=num_classes)                 #(9)

# Codeblock 8b
    def forward(self, x):
        print(f'original\t\t: {x.size()}')

        x = self.first_conv(x)
        print(f'after first_conv\t: {x.size()}')

        x = self.first_pool(x)
        print(f'after first_pool\t: {x.size()}')

        for stage in range(self.num_stages):
            x = getattr(self, f'dense_block_{stage}')(x)
            print(f'after dense_block_{stage}\t: {x.size()}')

            # The last dense block has no transition registered.
            transition = getattr(self, f'transition_{stage}', None)
            if transition is not None:
                x = transition(x)
                print(f'after transition_{stage}\t: {x.size()}')

        x = self.avgpool(x)
        print(f'after avgpool\t\t: {x.size()}')

        # Collapse the 1x1 spatial dimensions so the linear layer gets (batch, channels).
        x = torch.flatten(x, start_dim=1)
        print(f'after flatten\t\t: {x.size()}')

        x = self.fc(x)
        print(f'after fc\t\t: {x.size()}')

        return x

In [11]:
# Codeblock 9
# Shape check: build the full network and run one random 3-channel 224x224 input.
densenet = DenseNet()
x = torch.randn(1, 3, 224, 224)

# DenseNet.forward prints the tensor size after every stage.
x = densenet(x)

original		: torch.Size([1, 3, 224, 224])
after first_conv	: torch.Size([1, 64, 112, 112])
after first_pool	: torch.Size([1, 64, 56, 56])
after bottleneck #0	: torch.Size([1, 76, 56, 56])
after bottleneck #1	: torch.Size([1, 88, 56, 56])
after bottleneck #2	: torch.Size([1, 100, 56, 56])
after bottleneck #3	: torch.Size([1, 112, 56, 56])
after bottleneck #4	: torch.Size([1, 124, 56, 56])
after bottleneck #5	: torch.Size([1, 136, 56, 56])
after dense_block_0	: torch.Size([1, 136, 56, 56])
after transition_0	: torch.Size([1, 68, 28, 28])
after bottleneck #0	: torch.Size([1, 80, 28, 28])
after bottleneck #1	: torch.Size([1, 92, 28, 28])
after bottleneck #2	: torch.Size([1, 104, 28, 28])
after bottleneck #3	: torch.Size([1, 116, 28, 28])
after bottleneck #4	: torch.Size([1, 128, 28, 28])
after bottleneck #5	: torch.Size([1, 140, 28, 28])
after bottleneck #6	: torch.Size([1, 152, 28, 28])
after bottleneck #7	: torch.Size([1, 164, 28, 28])
after bottleneck #8	: torch.Size([1, 176, 28, 28])
af