In [1]:
# Codeblock 1
import torch
import torch.nn as nn

GROWTH          = 12
CHANNEL_POOLING = 0.8
COMPRESSION     = 0.5
REPEATS         = [6, 12, 24, 16]

In [4]:
# Codeblock 2
class Bottleneck(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=0.2)
        
        self.bn0   = nn.BatchNorm2d(num_features=in_channels)
        self.conv0 = nn.Conv2d(in_channels=in_channels, 
                               out_channels=GROWTH*4,          
                               kernel_size=1, 
                               padding=0, 
                               bias=False)
        
        self.bn1   = nn.BatchNorm2d(num_features=GROWTH*4)
        self.conv1 = nn.Conv2d(in_channels=GROWTH*4, 
                               out_channels=GROWTH,            
                               kernel_size=3, 
                               padding=1, 
                               bias=False)
    
    def forward(self, x):
        #print(f'original\t: {x.size()}')
        
        out = self.dropout(self.conv0(self.relu(self.bn0(x))))
        #print(f'after conv0\t: {out.size()}')
        
        out = self.dropout(self.conv1(self.relu(self.bn1(out))))
        #print(f'after conv1\t: {out.size()}')
        
        concatenated = torch.cat((out, x), dim=1)              
        #print(f'after concat\t: {concatenated.size()}')
        
        return concatenated

In [3]:
# Codeblock 3
bottleneck = Bottleneck(in_channels=32)

x = torch.randn(1, 32, 56, 56)
x = bottleneck(x)

original	: torch.Size([1, 32, 56, 56])
after conv0	: torch.Size([1, 48, 56, 56])
after conv1	: torch.Size([1, 12, 56, 56])
after concat	: torch.Size([1, 44, 56, 56])


In [7]:
# Codeblock 4
class DenseBlock(nn.Module):
    def __init__(self, in_channels, repeats):
        super().__init__()
        self.bottlenecks = nn.ModuleList()
        
        for i in range(repeats):
            current_in_channels = in_channels + i * GROWTH
            self.bottlenecks.append(Bottleneck(in_channels=current_in_channels))
        
    def forward(self, x):
        #print(f'original\t\t\t: {x.size()}')
        
        for i, bottleneck in enumerate(self.bottlenecks):
            x = bottleneck(x)
            print(f'after bottleneck #{i}\t\t: {x.size()}')
            
        return x

In [6]:
# Codeblock 5
dense_block = DenseBlock(in_channels=32, repeats=6)
x = torch.randn(1, 32, 56, 56)

x = dense_block(x)

original			: torch.Size([1, 32, 56, 56])
after bottleneck #0		: torch.Size([1, 44, 56, 56])
after bottleneck #1		: torch.Size([1, 56, 56, 56])
after bottleneck #2		: torch.Size([1, 68, 56, 56])
after bottleneck #3		: torch.Size([1, 80, 56, 56])
after bottleneck #4		: torch.Size([1, 92, 56, 56])
after bottleneck #5		: torch.Size([1, 104, 56, 56])


In [10]:
# Codeblock 6
class FirstTransition(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        
        self.bn   = nn.BatchNorm2d(num_features=in_channels)
        self.relu = nn.ReLU()
        self.conv = nn.Conv2d(in_channels=in_channels, 
                              out_channels=out_channels,   #(1)
                              kernel_size=1,               #(2)
                              padding=0,
                              bias=False)
        self.dropout = nn.Dropout(p=0.2)
     
    def forward(self, x):
        #print(f'original\t\t: {x.size()}')
        
        out = self.dropout(self.conv(self.relu(self.bn(x))))
        #print(f'after first_transition\t: {out.size()}')
        
        return out

In [9]:
# Codeblock 7
first_transition = FirstTransition(in_channels=104, 
                                   out_channels=int(104*CHANNEL_POOLING)) #(1)

x = torch.randn(1, 104, 56, 56)
x = first_transition(x)

original		: torch.Size([1, 104, 56, 56])
after first_transition	: torch.Size([1, 83, 56, 56])


In [13]:
# Codeblock 8
class SecondTransition(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        
        self.bn   = nn.BatchNorm2d(num_features=in_channels)
        self.relu = nn.ReLU()
        self.conv = nn.Conv2d(in_channels=in_channels, 
                              out_channels=out_channels, 
                              kernel_size=1, 
                              padding=0,
                              bias=False)
        self.dropout = nn.Dropout(p=0.2)
        self.pool = nn.AvgPool2d(kernel_size=2, stride=2)    #(1)
     
    def forward(self, x):
        #print(f'original\t\t: {x.size()}')

        out = self.pool(self.dropout(self.conv(self.relu(self.bn(x)))))
        #print(f'after second_transition\t: {out.size()}')
        
        return out

In [12]:
# Codeblock 9
second_transition = SecondTransition(in_channels=115, 
                                     out_channels=int(115*COMPRESSION))

x = torch.randn(1, 115, 56, 56)
x = second_transition(x)

original		: torch.Size([1, 115, 56, 56])
after second_transition	: torch.Size([1, 57, 28, 28])


In [14]:
# Codeblock 10a
class CSPDenseNet(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.first_conv = nn.Conv2d(in_channels=3,         #(1)
                                    out_channels=64, 
                                    kernel_size=7,    
                                    stride=2,         
                                    padding=3,        
                                    bias=False)
        self.first_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)  #(2)
        channel_count = 64
        
        
        
        ##### Stage 0
        self.dense_block_0 = DenseBlock(in_channels=channel_count//2, 
                                        repeats=REPEATS[0])
        
        self.first_transition_0 = FirstTransition(in_channels=(channel_count//2)+(REPEATS[0]*GROWTH), 
                                                  out_channels=int(((channel_count//2)+(REPEATS[0]*GROWTH))*CHANNEL_POOLING))
        
        channel_count = (channel_count - (channel_count//2)) + int(((channel_count//2)+(REPEATS[0]*GROWTH))*CHANNEL_POOLING)
        
        self.second_transition_0 = SecondTransition(in_channels=channel_count, 
                                                  out_channels=int(channel_count*COMPRESSION))
        
        channel_count = int(channel_count*COMPRESSION)
        #####
        
        
        ##### Stage 1
        self.dense_block_1 = DenseBlock(in_channels=channel_count//2, 
                                        repeats=REPEATS[1])
        
        self.first_transition_1 = FirstTransition(in_channels=(channel_count//2)+(REPEATS[1]*GROWTH), 
                                                  out_channels=int(((channel_count//2)+(REPEATS[1]*GROWTH))*CHANNEL_POOLING))
        
        channel_count = (channel_count - (channel_count//2)) + int(((channel_count//2)+(REPEATS[1]*GROWTH))*CHANNEL_POOLING)
        
        self.second_transition_1 = SecondTransition(in_channels=channel_count, 
                                                  out_channels=int(channel_count*COMPRESSION))
        
        channel_count = int(channel_count*COMPRESSION)
        #####
        
        
        ##### Stage 2
        self.dense_block_2 = DenseBlock(in_channels=channel_count//2, 
                                        repeats=REPEATS[2])
        
        self.first_transition_2 = FirstTransition(in_channels=(channel_count//2)+(REPEATS[2]*GROWTH), 
                                                  out_channels=int(((channel_count//2)+(REPEATS[2]*GROWTH))*CHANNEL_POOLING))
        
        channel_count = (channel_count - (channel_count//2)) + int(((channel_count//2)+(REPEATS[2]*GROWTH))*CHANNEL_POOLING)
        
        self.second_transition_2 = SecondTransition(in_channels=channel_count, 
                                                  out_channels=int(channel_count*COMPRESSION))
        
        channel_count = int(channel_count*COMPRESSION)
        #####
        
        
        ##### Stage 3
        self.dense_block_3 = DenseBlock(in_channels=channel_count//2, 
                                        repeats=REPEATS[3])
        
        self.first_transition_3 = FirstTransition(in_channels=(channel_count//2)+(REPEATS[3]*GROWTH), 
                                                  out_channels=int(((channel_count//2)+(REPEATS[3]*GROWTH))*CHANNEL_POOLING))
        
        channel_count = (channel_count - (channel_count//2)) + int(((channel_count//2)+(REPEATS[3]*GROWTH))*CHANNEL_POOLING)
        #####
        
        
        self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1,1))             #(3)
        self.fc = nn.Linear(in_features=channel_count, out_features=1000)  #(4)
        
# Codeblock 10b
    def split_channels(self, x):

        channel_count = x.size(1)

        if channel_count%2 != 0:
            split_size_2 = channel_count // 2            #(1)
            split_size_1 = channel_count - split_size_2  #(2)
            return torch.split(x, [split_size_1, split_size_2], dim=1)  #(3)

        else:
            return torch.split(x, channel_count // 2, dim=1)            #(4)
        
# Codeblock 10c
    def forward(self, x):
        print(f'original\t\t\t: {x.size()}')
        
        x = self.first_conv(x)
        print(f'after first_conv\t\t: {x.size()}')
        
        x = self.first_pool(x)      #(1)
        print(f'after first_pool\t\t: {x.size()}\n')
        
        
        
        ##### Stage 0
        part1, part2 = self.split_channels(x)    #(2)
        print(f'part1\t\t\t\t: {part1.size()}')
        print(f'part2\t\t\t\t: {part2.size()}')
        
        part2 = self.dense_block_0(part2)        #(3)
        print(f'part2 after dense block 0\t: {part2.size()}')
        
        part2 = self.first_transition_0(part2)   #(4)
        print(f'part2 after first trans 0\t: {part2.size()}')
        
        x = torch.cat((part1, part2), dim=1)     #(xx)
        print(f'after concatenate\t\t: {x.size()}')
        
        x = self.second_transition_0(x)          #(xx)
        print(f'after second transition 0\t: {x.size()}\n')
        
        
        
        ##### Stage 1
        part1, part2 = self.split_channels(x)
        print(f'part1\t\t\t\t: {part1.size()}')
        print(f'part2\t\t\t\t: {part2.size()}')
        
        part2 = self.dense_block_1(part2)
        print(f'part2 after dense block 1\t: {part2.size()}')
        
        part2 = self.first_transition_1(part2)
        print(f'part2 after first trans 1\t: {part2.size()}')
        
        x = torch.cat((part1, part2), dim=1)
        print(f'after concatenate\t\t: {x.size()}')
        
        x = self.second_transition_1(x)
        print(f'after second transition 1\t: {x.size()}\n')
        
        
        
        ##### Stage 2
        part1, part2 = self.split_channels(x)
        print(f'part1\t\t\t\t: {part1.size()}')
        print(f'part2\t\t\t\t: {part2.size()}')
        
        part2 = self.dense_block_2(part2)
        print(f'part2 after dense block 2\t: {part2.size()}')
        
        part2 = self.first_transition_2(part2)
        print(f'part2 after first trans 2\t: {part2.size()}')
        
        x = torch.cat((part1, part2), dim=1)
        print(f'after concatenate\t\t: {x.size()}')
        
        x = self.second_transition_2(x)
        print(f'after second transition 2\t: {x.size()}\n')
        
        
        
        ##### Stage 3
        part1, part2 = self.split_channels(x)
        print(f'part1\t\t\t\t: {part1.size()}')
        print(f'part2\t\t\t\t: {part2.size()}')
        
        part2 = self.dense_block_3(part2)
        print(f'part2 after dense block 2\t: {part2.size()}')
        
        part2 = self.first_transition_3(part2)
        print(f'part2 after first trans 2\t: {part2.size()}')
        
        x = torch.cat((part1, part2), dim=1)
        print(f'after concatenate\t\t: {x.size()}\n')
        
        
        
        x = self.avgpool(x)
        print(f'after avgpool\t\t\t: {x.size()}')
        
        x = torch.flatten(x, start_dim=1)
        print(f'after flatten\t\t\t: {x.size()}')
        
        x = self.fc(x)
        print(f'after fc\t\t\t: {x.size()}')
        
        return x

In [15]:
# Codeblock 11
cspdensenet = CSPDenseNet()

x = torch.randn(1, 3, 224, 224)
x = cspdensenet(x)

original			: torch.Size([1, 3, 224, 224])
after first_conv		: torch.Size([1, 64, 112, 112])
after first_pool		: torch.Size([1, 64, 56, 56])

part1				: torch.Size([1, 32, 56, 56])
part2				: torch.Size([1, 32, 56, 56])
after bottleneck #0		: torch.Size([1, 44, 56, 56])
after bottleneck #1		: torch.Size([1, 56, 56, 56])
after bottleneck #2		: torch.Size([1, 68, 56, 56])
after bottleneck #3		: torch.Size([1, 80, 56, 56])
after bottleneck #4		: torch.Size([1, 92, 56, 56])
after bottleneck #5		: torch.Size([1, 104, 56, 56])
part2 after dense block 0	: torch.Size([1, 104, 56, 56])
part2 after first trans 0	: torch.Size([1, 83, 56, 56])
after concatenate		: torch.Size([1, 115, 56, 56])
after second transition 0	: torch.Size([1, 57, 28, 28])

part1				: torch.Size([1, 29, 28, 28])
part2				: torch.Size([1, 28, 28, 28])
after bottleneck #0		: torch.Size([1, 40, 28, 28])
after bottleneck #1		: torch.Size([1, 52, 28, 28])
after bottleneck #2		: torch.Size([1, 64, 28, 28])
after bottleneck #3		: 