In [1]:
# Codeblock 1
import torch
import torch.nn as nn

In [2]:
# Codeblock 2
class CNN(nn.Module):
    """Toy 5-stage convolutional backbone that returns multi-scale feature maps.

    Every stage is ``conv3x3 (padding 1) -> ReLU -> MaxPool2d(3, stride 2, padding 1)``.
    The convs keep the spatial size and each pool maps a side h to ceil(h / 2).
    ``forward`` returns the pooled outputs of stages 2-5 as ``(c2, c3, c4, c5)``
    with 256, 512, 1024 and 2048 channels (56, 28, 14 and 7 px for a 224x224 input).

    Args:
        verbose: when True (the default, preserving the original behaviour) the
            tensor size after every conv and pooling step is printed.
    """
    def __init__(self, verbose=True):
        super().__init__()
        
        self.verbose = verbose
        
        # Shared, parameter-free layers reused by every stage.
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=256, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(in_channels=1024, out_channels=2048, kernel_size=3, padding=1)
        
    def _log(self, message):
        """Print `message` only when the module was built with verbose=True."""
        if self.verbose:
            print(message)
        
    def forward(self, x):
        """Run the backbone on `x` and return the ``(c2, c3, c4, c5)`` feature maps.

        Assumes x is (batch, 3, H, W); conv1 expects 3 input channels.
        """
        self._log(f'original\t: {x.size()}\n')
        
        # Stage 1 (stem): its output is not returned.
        x = self.relu(self.conv1(x))
        self._log(f'after conv1\t: {x.size()}')
        
        x = self.maxpool(x)
        self._log(f'after pool\t: {x.size()}\n')
        
        # Stages 2-5. The pooled maps are stored without .clone(): every later op
        # (non-inplace ReLU, conv, pool) returns a new tensor, so a stored map is
        # never modified afterwards.
        features = []
        for level, conv in enumerate((self.conv2, self.conv3, self.conv4, self.conv5), start=2):
            x = self.relu(conv(x))
            self._log(f'after conv{level}\t: {x.size()}')
            
            x = self.maxpool(x)
            self._log(f'after pool (c{level})\t: {x.size()}\n')
            
            features.append(x)
        
        return tuple(features)

In [3]:
# Codeblock 3
cnn = CNN()

# One random RGB image; CNN.forward prints the size after every conv / pool.
x = torch.randn((1, 3, 224, 224))
out_cnn = cnn(x)

(c2, c3, c4, c5) = out_cnn

original	: torch.Size([1, 3, 224, 224])

after conv1	: torch.Size([1, 64, 224, 224])
after pool	: torch.Size([1, 64, 112, 112])

after conv2	: torch.Size([1, 256, 112, 112])
after pool (c2)	: torch.Size([1, 256, 56, 56])

after conv3	: torch.Size([1, 512, 56, 56])
after pool (c3)	: torch.Size([1, 512, 28, 28])

after conv4	: torch.Size([1, 1024, 28, 28])
after pool (c4)	: torch.Size([1, 1024, 14, 14])

after conv5	: torch.Size([1, 2048, 14, 14])
after pool (c5)	: torch.Size([1, 2048, 7, 7])



In [4]:
# Codeblock 4
class FPN(nn.Module):
    """Feature Pyramid Network: top-down pathway with lateral connections.

    Takes backbone maps ``c2..c5`` (256, 512, 1024 and 2048 channels, coarsest
    last) and returns ``p2..p5``. Each output has 256 channels and the same
    spatial size as the corresponding ``c`` input:

    * ``m5 = lateral_c5(c5)`` and ``p5 = m5``;
    * ``m_i = upsample(m_{i+1}) + lateral_ci(c_i)`` and ``p_i = smooth_mi(m_i)``.
    """
    def __init__(self):
        super().__init__()
        
        # Kept so existing code that reads `fpn.upsample` still works. forward()
        # goes through _upsample_add, which resizes to the exact lateral size
        # (identical to this layer when each level is exactly twice the next).
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')    #(1)
        
        # 1x1 lateral convs project every backbone level to 256 channels.
        self.lateral_c5 = nn.Conv2d(in_channels=2048, out_channels=256, kernel_size=1)
        self.lateral_c4 = nn.Conv2d(in_channels=1024, out_channels=256, kernel_size=1)
        self.lateral_c3 = nn.Conv2d(in_channels=512,  out_channels=256, kernel_size=1)
        self.lateral_c2 = nn.Conv2d(in_channels=256,  out_channels=256, kernel_size=1)
        
        # 3x3 convs applied to each merged map.
        self.smooth_m4  = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1)
        self.smooth_m3  = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1)
        self.smooth_m2  = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1)
        
    def _upsample_add(self, top, lateral):
        """Nearest-upsample `top` to the spatial size of `lateral` and add the two.

        A fixed 2x upsample breaks on odd-sized levels (e.g. a 15x15 c4 under an
        8x8 c5: 8 -> 16 != 15) and the addition raises a size mismatch. Resizing to
        the lateral's exact size handles any input resolution.
        """
        upsampled = nn.functional.interpolate(top, size=lateral.shape[-2:], mode='nearest')
        return upsampled + lateral
        
    def forward(self, c2, c3, c4, c5):
        """Return ``(p2, p3, p4, p5)`` computed from backbone maps ``(c2, c3, c4, c5)``."""
        m5 = self.lateral_c5(c5)
        p5 = m5
        
        m4 = self._upsample_add(m5, self.lateral_c4(c4))
        p4 = self.smooth_m4(m4)
        
        m3 = self._upsample_add(m4, self.lateral_c3(c3))
        p3 = self.smooth_m3(m3)
        
        m2 = self._upsample_add(m3, self.lateral_c2(c2))
        p2 = self.smooth_m2(m2)
        
        return p2, p3, p4, p5

In [5]:
# Codeblock 5
fpn = FPN()

# Top-down FPN over the backbone maps produced in the previous cell.
out_fpn = fpn(c2, c3, c4, c5)
(p2, p3, p4, p5) = out_fpn

print(f'p2: {p2.size()}',
      f'p3: {p3.size()}',
      f'p4: {p4.size()}',
      f'p5: {p5.size()}',
      sep='\n')

p2: torch.Size([1, 256, 56, 56])
p3: torch.Size([1, 256, 28, 28])
p4: torch.Size([1, 256, 14, 14])
p5: torch.Size([1, 256, 7, 7])


In [8]:
# Codeblock 6
class PANet(nn.Module):
    """PANet bottom-up path augmentation applied on top of FPN outputs.

    Given ``p2..p5`` (256 channels each), returns ``n2..n5``. ``n2 = p2``, and
    each higher level is ``n_{i+1} = ReLU(conv3x3(ReLU(conv3x3_stride2(n_i)) + p_{i+1}))``.
    The stride-2 conv (kernel 3, padding 1) maps a side h to ceil(h / 2), so
    ``p_{i+1}`` must have exactly that spatial size for the addition to work.
    """
    def __init__(self):
        super().__init__()
        
        # Stride-2 3x3 convs that halve n_i before it is merged with p_{i+1}.
        self.downsample_n2 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=2, padding=1)
        self.downsample_n3 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=2, padding=1)
        self.downsample_n4 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=2, padding=1)
        
        # 3x3 convs applied to each merged (downsampled n_i + p_{i+1}) map.
        self.conv_n2down_p3 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1)
        self.conv_n3down_p4 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1)
        self.conv_n4down_p5 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1)
        
        self.relu = nn.ReLU()

    def _bottom_up_step(self, n_prev, p_next, downsample, fuse):
        """Build the next level: ``ReLU(fuse(ReLU(downsample(n_prev)) + p_next))``."""
        n_down = self.relu(downsample(n_prev))          #(2)
        merged = n_down + p_next                        #(3)
        return self.relu(fuse(merged))                  #(4)

    def forward(self, p2, p3, p4, p5):
        """Return ``(n2, n3, n4, n5)``; ``n2`` is the very same tensor object as ``p2``."""
        n2 = p2                                         #(1)
        n3 = self._bottom_up_step(n2, p3, self.downsample_n2, self.conv_n2down_p3)
        n4 = self._bottom_up_step(n3, p4, self.downsample_n3, self.conv_n3down_p4)
        n5 = self._bottom_up_step(n4, p5, self.downsample_n4, self.conv_n4down_p5)
        
        return n2, n3, n4, n5

In [7]:
# Codeblock 7
panet = PANet()

# Bottom-up path aggregation on top of the FPN outputs from the previous cell.
out_panet = panet(p2, p3, p4, p5)
(n2, n3, n4, n5) = out_panet

n2		: torch.Size([1, 256, 56, 56])

n2 downsampled	: torch.Size([1, 256, 28, 28])
after sum	: torch.Size([1, 256, 28, 28])
n3		: torch.Size([1, 256, 28, 28])

n3 downsampled	: torch.Size([1, 256, 14, 14])
after sum	: torch.Size([1, 256, 14, 14])
n4		: torch.Size([1, 256, 14, 14])

n4 downsampled	: torch.Size([1, 256, 7, 7])
after sum	: torch.Size([1, 256, 7, 7])
n5		: torch.Size([1, 256, 7, 7])


In [9]:
# Codeblock 8
class BackboneNeck(nn.Module):
    """End-to-end feature extractor: CNN backbone -> FPN -> PANet.

    ``forward(x)`` returns the PANet maps ``(n2, n3, n4, n5)`` and prints the
    sizes of the c-, p- and n-level maps along the way.
    """
    def __init__(self):
        super().__init__()
        
        # Registration order (cnn, fpn, panet) fixes the state_dict key order.
        self.cnn   = CNN()
        self.fpn   = FPN()
        self.panet = PANet()
    
    def forward(self, x):
        """Run the three stages on `x` and return ``(n2, n3, n4, n5)``."""
        c_maps = self.cnn(x)
        c2, c3, c4, c5 = c_maps
        print(f'c2: {c2.size()}',
              f'c3: {c3.size()}',
              f'c4: {c4.size()}',
              f'c5: {c5.size()}\n',
              sep='\n')
        
        p_maps = self.fpn(*c_maps)
        p2, p3, p4, p5 = p_maps
        print(f'p2: {p2.size()}',
              f'p3: {p3.size()}',
              f'p4: {p4.size()}',
              f'p5: {p5.size()}\n',
              sep='\n')
        
        n_maps = self.panet(*p_maps)
        n2, n3, n4, n5 = n_maps
        print(f'n2: {n2.size()}',
              f'n3: {n3.size()}',
              f'n4: {n4.size()}',
              f'n5: {n5.size()}',
              sep='\n')
        
        return n2, n3, n4, n5

In [10]:
# Codeblock 9
backbone_neck = BackboneNeck()

# End-to-end run (backbone -> FPN -> PANet) on one random 224x224 RGB image.
x = torch.randn((1, 3, 224, 224))
(n2, n3, n4, n5) = backbone_neck(x)

original	: torch.Size([1, 3, 224, 224])

after conv1	: torch.Size([1, 64, 224, 224])
after pool	: torch.Size([1, 64, 112, 112])

after conv2	: torch.Size([1, 256, 112, 112])
after pool (c2)	: torch.Size([1, 256, 56, 56])

after conv3	: torch.Size([1, 512, 56, 56])
after pool (c3)	: torch.Size([1, 512, 28, 28])

after conv4	: torch.Size([1, 1024, 28, 28])
after pool (c4)	: torch.Size([1, 1024, 14, 14])

after conv5	: torch.Size([1, 2048, 14, 14])
after pool (c5)	: torch.Size([1, 2048, 7, 7])

c2: torch.Size([1, 256, 56, 56])
c3: torch.Size([1, 512, 28, 28])
c4: torch.Size([1, 1024, 14, 14])
c5: torch.Size([1, 2048, 7, 7])

p2: torch.Size([1, 256, 56, 56])
p3: torch.Size([1, 256, 28, 28])
p4: torch.Size([1, 256, 14, 14])
p5: torch.Size([1, 256, 7, 7])

n2: torch.Size([1, 256, 56, 56])
n3: torch.Size([1, 256, 28, 28])
n4: torch.Size([1, 256, 14, 14])
n5: torch.Size([1, 256, 7, 7])
