In [1]:
# Codeblock 1
import torch
import torch.nn as nn

In [5]:
# Codeblock 2
class CNN(nn.Module):
    """Toy convolutional backbone that exposes four intermediate feature maps.

    Five ``conv -> ReLU -> max-pool`` stages are applied in sequence. Every
    3x3 convolution uses ``padding=1`` and therefore keeps the spatial size;
    every max-pool (kernel 3, stride 2, padding 1) halves it (rounding up).
    The pooled outputs of stages 2-5 are returned as ``c2, c3, c4, c5`` with
    256, 512, 1024 and 2048 channels respectively.

    For a ``(1, 3, 224, 224)`` input the returned maps have spatial sizes
    56, 28, 14 and 7.
    """

    def __init__(self):
        super().__init__()

        # Parameter-free layers, reused by every stage.
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=256, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(in_channels=1024, out_channels=2048, kernel_size=3, padding=1)

    def forward(self, x):
        """Run the backbone and collect the stage 2-5 outputs.

        Args:
            x: image batch with 3 channels, laid out as ``(N, 3, H, W)``.

        Returns:
            Tuple ``(c2, c3, c4, c5)`` of pooled feature maps, from the
            highest-resolution map (256 channels) to the lowest (2048 channels).
        """
        x = self.maxpool(self.relu(self.conv1(x)))

        # None of the layers below work in place (nn.ReLU defaults to
        # inplace=False), so each stage output can be kept by reference;
        # cloning it would only duplicate memory and add a graph node.
        c2 = self.maxpool(self.relu(self.conv2(x)))     #(1)
        c3 = self.maxpool(self.relu(self.conv3(c2)))    #(2)
        c4 = self.maxpool(self.relu(self.conv4(c3)))    #(3)
        c5 = self.maxpool(self.relu(self.conv5(c4)))

        return c2, c3, c4, c5      #(4)

In [3]:
# Codeblock 3
# Build the backbone and push one random 224x224 RGB image through it.
cnn = CNN()

x = torch.randn((1, 3, 224, 224))
out_cnn = cnn(x)

original	: torch.Size([1, 3, 224, 224])

after conv1	: torch.Size([1, 64, 224, 224])
after maxpool	: torch.Size([1, 64, 112, 112])

after conv2	: torch.Size([1, 256, 112, 112])
after maxpool	: torch.Size([1, 256, 56, 56])

after conv3	: torch.Size([1, 512, 56, 56])
after maxpool	: torch.Size([1, 512, 28, 28])

after conv4	: torch.Size([1, 1024, 28, 28])
after maxpool	: torch.Size([1, 1024, 14, 14])

after conv5	: torch.Size([1, 2048, 14, 14])
after maxpool	: torch.Size([1, 2048, 7, 7])



In [4]:
# Codeblock 4
c2, c3, c4, c5 = out_cnn

# One shape per line, from the highest-resolution map (c2) to the lowest (c5).
print(*(feature_map.shape for feature_map in (c2, c3, c4, c5)), sep="\n")

torch.Size([1, 256, 56, 56])
torch.Size([1, 512, 28, 28])
torch.Size([1, 1024, 14, 14])
torch.Size([1, 2048, 7, 7])


In [6]:
# Codeblock 5
class FPN(nn.Module):
    """Top-down feature pyramid built on four backbone feature maps.

    Each backbone map is projected to 256 channels by a 1x1 "lateral"
    convolution. Starting from the coarsest level, the running map is
    upsampled (nearest neighbour), added to the next lateral map, and the sum
    is passed through a 3x3 "smoothing" convolution to give that level's
    output. The coarsest level (p5) is the lateral projection itself.
    """

    def __init__(self):
        super().__init__()

        self.lateral_c5 = nn.Conv2d(in_channels=2048, out_channels=256, kernel_size=1)
        self.lateral_c4 = nn.Conv2d(in_channels=1024, out_channels=256, kernel_size=1)
        self.lateral_c3 = nn.Conv2d(in_channels=512,  out_channels=256, kernel_size=1)
        self.lateral_c2 = nn.Conv2d(in_channels=256,  out_channels=256, kernel_size=1)

        self.smooth_m4  = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1)
        self.smooth_m3  = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1)
        self.smooth_m2  = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1)

    @staticmethod
    def _upsample_like(coarse, reference):
        """Nearest-neighbour resize ``coarse`` to the height/width of ``reference``."""
        # Resizing to the target's real size (instead of a fixed x2 factor)
        # keeps the addition valid when a stride-2 stage rounded an odd size
        # up, e.g. 7 -> 4, where 4 * 2 != 7. For an exact x2 relationship the
        # result is the same as scale_factor=2.
        return nn.functional.interpolate(coarse, size=reference.shape[-2:], mode='nearest')    #(1)

    def forward(self, c2, c3, c4, c5):
        """Merge the backbone maps into pyramid levels.

        Args:
            c2, c3, c4, c5: backbone feature maps with 256, 512, 1024 and 2048
                channels, ordered from highest to lowest resolution.

        Returns:
            Tuple ``(p2, p3, p4, p5)``; every map has 256 channels and the
            spatial size of the corresponding input.
        """
        m5 = self.lateral_c5(c5)
        p5 = m5

        lateral4 = self.lateral_c4(c4)
        m4 = self._upsample_like(m5, lateral4) + lateral4
        p4 = self.smooth_m4(m4)

        lateral3 = self.lateral_c3(c3)
        m3 = self._upsample_like(m4, lateral3) + lateral3
        p3 = self.smooth_m3(m3)

        lateral2 = self.lateral_c2(c2)
        m2 = self._upsample_like(m3, lateral2) + lateral2
        p2 = self.smooth_m2(m2)

        return p2, p3, p4, p5

In [7]:
# Codeblock 6
fpn = FPN()

# Keep both the packed result and the individual pyramid levels.
p2, p3, p4, p5 = out_fpn = fpn(c2, c3, c4, c5)

# One shape per line, highest resolution first.
print(*(level.shape for level in (p2, p3, p4, p5)), sep="\n")

torch.Size([1, 256, 56, 56])
torch.Size([1, 256, 28, 28])
torch.Size([1, 256, 14, 14])
torch.Size([1, 256, 7, 7])


In [8]:
# Codeblock 7
NUM_ANCHORS = 3

class RPN(nn.Module):
    """Convolutional region-proposal head applied to one feature map.

    A 3x3 convolution followed by ReLU produces an intermediate feature map,
    from which two sibling 1x1 convolutions predict, at every spatial
    location, ``2 * num_anchors`` objectness values and ``4 * num_anchors``
    box-regression values.

    Args:
        num_anchors: anchors predicted per spatial location
            (defaults to ``NUM_ANCHORS``).
        in_channels: channel count of the incoming feature map (default 256).
    """

    def __init__(self, num_anchors=NUM_ANCHORS, in_channels=256):
        super().__init__()

        self.intermediate = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=3, padding=1)
        # Without a non-linearity here, the 3x3 conv and the 1x1 heads would
        # collapse into a single linear map and the extra layer would add
        # nothing.
        self.relu = nn.ReLU()

        self.cls = nn.Conv2d(in_channels=in_channels, out_channels=num_anchors*2, kernel_size=1)
        self.reg = nn.Conv2d(in_channels=in_channels, out_channels=num_anchors*4, kernel_size=1)

    def forward(self, x):
        """Predict objectness and box offsets for every location of ``x``.

        Args:
            x: feature map laid out as ``(N, in_channels, H, W)``.

        Returns:
            Tuple ``(objectness_scores, bbox_regressions)`` with shapes
            ``(N, 2 * num_anchors, H, W)`` and ``(N, 4 * num_anchors, H, W)``.
        """
        x = self.relu(self.intermediate(x))

        objectness_scores = self.cls(x)
        bbox_regressions  = self.reg(x)

        return objectness_scores, bbox_regressions

In [9]:
# Codeblock 8
rpn = RPN()

# Apply the proposal head to the highest-resolution pyramid level only.
p2_objectness, p2_bbox = rpn(p2)

print(p2_objectness.shape, p2_bbox.shape, sep="\n")

torch.Size([1, 6, 56, 56])
torch.Size([1, 12, 56, 56])


In [10]:
# Codeblock 9
class DetectionModel(nn.Module):
    """Backbone, feature pyramid and proposal head chained into one module.

    The input is passed through ``CNN``, its four feature maps are merged by
    ``FPN``, and a single ``RPN`` instance (shared weights) is applied to each
    of the four pyramid levels.
    """

    def __init__(self):
        super().__init__()

        self.cnn = CNN()
        self.fpn = FPN()
        self.rpn = RPN()

    def forward(self, x):
        """Return the RPN predictions for every pyramid level.

        Returns:
            A 4-tuple ordered as (p2, p3, p4, p5); each entry is whatever
            ``self.rpn`` returns for that level.
        """
        backbone_maps = self.cnn(x)               #(1)
        pyramid = self.fpn(*backbone_maps)        #(2)

        # The same head is reused for p2 .. p5, in that order.   #(3)-(4)
        return tuple(self.rpn(level) for level in pyramid)

In [11]:
# Codeblock 10
detection_model = DetectionModel()

x = torch.randn((1, 3, 224, 224))     #(1)
p2_pred, p3_pred, p4_pred, p5_pred = detection_model(x)  #(2)

# Split each level's prediction into its objectness and box parts.
p2_objectness, p2_bbox = p2_pred    #(3)
p3_objectness, p3_bbox = p3_pred
p4_objectness, p4_bbox = p4_pred
p5_objectness, p5_bbox = p5_pred    #(4)

# Objectness shapes, a blank line, then box-regression shapes.
print(
    p2_objectness.shape,
    p3_objectness.shape,
    p4_objectness.shape,
    p5_objectness.shape,
    sep="\n",
)
print()

print(
    p2_bbox.shape,
    p3_bbox.shape,
    p4_bbox.shape,
    p5_bbox.shape,
    sep="\n",
)

torch.Size([1, 6, 56, 56])
torch.Size([1, 6, 28, 28])
torch.Size([1, 6, 14, 14])
torch.Size([1, 6, 7, 7])

torch.Size([1, 12, 56, 56])
torch.Size([1, 12, 28, 28])
torch.Size([1, 12, 14, 14])
torch.Size([1, 12, 7, 7])
