In [1]:
# Codeblock 1
import torch
import torch.nn as nn
import torch.nn.functional as F

# Pyramid levels for spatial pyramid pooling: each entry g is the side length
# of a g x g adaptive max-pooling grid. SPP.forward iterates over this list in
# order, and ZF5_SPPNet uses it to size its first fully connected layer, so
# with [4, 2, 1] every channel contributes 16 + 4 + 1 = 21 pooled values.
GRIDS = [4, 2, 1]

In [2]:
# Codeblock 2
class SPP(nn.Module):
    """Spatial pyramid pooling: pool a feature map on several grids and concatenate.

    For every pyramid level ``g`` the input is adaptively max-pooled to a
    ``g x g`` grid, flattened from dimension 1 onwards, and all levels are
    concatenated along dimension 1. The length of the result therefore depends
    only on the number of channels and on the pyramid levels, not on the
    spatial size of the input.

    Args:
        grids: Iterable of grid side lengths, applied in the given order.
            ``None`` (the default) uses the notebook-level ``GRIDS`` constant,
            so ``SPP()`` behaves exactly as it did before this parameter
            existed. The levels are copied at construction time.

    Raises:
        ValueError: If the resulting pyramid has no levels.
    """

    def __init__(self, grids=None):
        super().__init__()
        # GRIDS is only looked up when no explicit pyramid is supplied.
        levels = GRIDS if grids is None else grids
        self.grids = tuple(levels)
        if not self.grids:
            # torch.cat on an empty list would fail later with a less helpful message.
            raise ValueError("SPP needs at least one pyramid level")

    def forward(self, x):
        """Pool ``x`` on every pyramid level and return the concatenated result.

        Args:
            x: Tensor accepted by ``F.adaptive_max_pool2d``. For a 4-D input of
                size ``(N, C, H, W)`` the result has size
                ``(N, C * sum(g * g for g in grids))``. A 3-D input is pooled as
                a single unbatched sample and, because flattening starts at
                dimension 1, its leading dimension is kept as-is.

        Returns:
            2-D tensor with the flattened pooled maps of all levels, in the
            order of ``grids``.

        The intermediate sizes are printed on every call.
        """
        pooled_outputs = []      #(1)
        for grid in self.grids:
            pooled = F.adaptive_max_pool2d(x, output_size=(grid, grid))  #(2)
            print(f'after pool\t\t: {pooled.size()}')

            # Keep dimension 0 and merge everything else into one feature axis.
            pooled = torch.flatten(pooled, start_dim=1)    #(3)
            print(f'after flatten\t\t: {pooled.size()}\n')

            pooled_outputs.append(pooled)                  #(4)

        concatenated = torch.cat(pooled_outputs, dim=1)    #(5)
        print(f'after concatenate\t: {concatenated.size()}\n')

        return concatenated

In [3]:
# Codeblock 3
# Toy example: run SPP on one hand-written 8x8 feature map so that the pooled
# maxima can be checked by eye against the matrix below.
spp = SPP()

# Four levels of nesting give a 4-D tensor of size (1, 1, 8, 8).
x = torch.tensor([[[[5, 4, 4, 5, 3, 2, 6, 5],
                    [2, 8, 6, 3, 1, 5, 8, 1],
                    [4, 3, 7, 5, 2, 4, 7, 3],
                    [2, 1, 3, 2, 5, 5, 4, 2],
                    [4, 6, 7, 9, 5, 6, 6, 7],
                    [6, 8, 5, 9, 4, 9, 4, 1],
                    [3, 7, 4, 6, 8, 4, 4, 9],
                    [2, 4, 4, 3, 8, 2, 9, 2]]]], dtype=torch.float32)  #(1)

# The call prints the size after each pooling level; `out` is inspected in the next cell.
out = spp(x)    #(2)

after pool		: torch.Size([1, 1, 4, 4])
after flatten		: torch.Size([1, 16])

after pool		: torch.Size([1, 1, 2, 2])
after flatten		: torch.Size([1, 4])

after pool		: torch.Size([1, 1, 1, 1])
after flatten		: torch.Size([1, 1])

after concatenate	: torch.Size([1, 21])



In [4]:
# Codeblock 4
# 21 values in pyramid order: the 16 maxima of the 4x4 grid (row by row),
# then the 4 maxima of the 2x2 grid, then the single global maximum.
print(out)

tensor([[8., 6., 5., 8., 4., 7., 5., 7., 8., 9., 9., 7., 7., 6., 8., 9., 8., 8.,
         9., 9., 9.]])


In [5]:
# Codeblock 5
# Same experiment with a taller feature map (12 rows x 8 columns) to show that
# the length of the SPP output does not change with the input size.
spp = SPP()

# NOTE(review): unlike the 8x8 example, this literal has only three levels of
# nesting, so x has size (1, 12, 8). The recorded "after pool" sizes are 3-D,
# i.e. the tensor is pooled as a single unbatched sample, and the final result
# is (1, 21) only because the leading dimension is 1. Add one more pair of
# brackets if a (1, 1, 12, 8) batch was intended.
x = torch.tensor([[[4, 3, 7, 5, 2, 4, 7, 3],
                   [2, 1, 3, 2, 5, 5, 4, 2],
                   [5, 4, 4, 5, 3, 2, 6, 5],
                   [2, 8, 6, 3, 1, 5, 8, 1],
                   [4, 3, 7, 5, 2, 4, 7, 3],
                   [2, 1, 3, 2, 5, 5, 4, 2],
                   [4, 6, 7, 9, 5, 6, 6, 7],
                   [6, 8, 5, 9, 4, 9, 4, 1],
                   [3, 7, 4, 6, 8, 4, 4, 9],
                   [2, 4, 4, 3, 8, 2, 9, 2], 
                   [6, 8, 5, 9, 4, 9, 4, 1],
                   [3, 7, 4, 6, 8, 4, 4, 9]]], dtype=torch.float32)

# Overwrites the `out` produced by the 8x8 example.
out = spp(x)

after pool		: torch.Size([1, 4, 4])
after flatten		: torch.Size([1, 16])

after pool		: torch.Size([1, 2, 2])
after flatten		: torch.Size([1, 4])

after pool		: torch.Size([1, 1, 1])
after flatten		: torch.Size([1, 1])

after concatenate	: torch.Size([1, 21])



In [6]:
# Codeblock 6
# Again 21 values (16 + 4 + 1), although the input now has 12 rows instead of 8.
print(out)

tensor([[5., 7., 5., 7., 8., 7., 5., 8., 8., 9., 9., 9., 8., 9., 9., 9., 8., 8.,
         9., 9., 9.]])


In [7]:
# Codeblock 7a
class ZF5_SPPNet(nn.Module):
    """ZF-5-style CNN with a spatial pyramid pooling layer before the classifier.

    Five convolutional layers (the first two each followed by local response
    normalisation and max pooling) feed an ``SPP`` layer. Its fixed-length
    output goes through two 4096-unit fully connected layers with dropout and
    a final linear layer.

    Args:
        num_classes: Number of outputs of the final ``fc8`` layer. Defaults to
            1000, the value that used to be hard-coded, so ``ZF5_SPPNet()``
            builds the same network as before.

    Note:
        The SPP layer is built with its defaults, i.e. it pools on the
        notebook-level ``GRIDS``; the input width of ``fc6`` is derived from
        the same constant.
    """

    def __init__(self, num_classes=1000):
        super().__init__()

        # ReLU has no parameters, so one instance is shared by every layer.
        self.relu = nn.ReLU()

        # Stage 1: convolution -> LRN -> max pooling.
        self.conv1 = nn.Conv2d(in_channels=3,
                               out_channels=96,
                               kernel_size=7,
                               stride=2,
                               padding=0)    #(1)
        self.norm1 = nn.LocalResponseNorm(size=5)
        self.pool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)    #(2)

        # Stage 2: convolution -> LRN -> max pooling.
        self.conv2 = nn.Conv2d(in_channels=96,
                               out_channels=256,
                               kernel_size=5,
                               stride=2,
                               padding=1)    #(3)
        self.norm2 = nn.LocalResponseNorm(size=5)
        self.pool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)    #(4)

        # Stages 3-5: 3x3 convolutions with stride 1 and padding 1, which keep
        # the spatial size unchanged.
        self.conv3 = nn.Conv2d(in_channels=256,
                               out_channels=384,
                               kernel_size=3,
                               stride=1,
                               padding=1)

        self.conv4 = nn.Conv2d(in_channels=384,
                               out_channels=384,
                               kernel_size=3,
                               stride=1,
                               padding=1)

        self.conv5 = nn.Conv2d(in_channels=384,
                               out_channels=256,
                               kernel_size=3,
                               stride=1,
                               padding=1)

        self.spp = SPP()    #(4)
        # Every conv5 channel contributes g*g values per pyramid level, so the
        # flattened SPP output has out_channels * sum(g^2) features
        # (256 * 21 = 5376 for GRIDS = [4, 2, 1]).
        spp_out_size = self.conv5.out_channels * sum(grid**2 for grid in GRIDS)    #(5)

        self.fc6 = nn.Linear(in_features=spp_out_size, out_features=4096)
        self.dropout6 = nn.Dropout(p=0.5)    #(6)

        self.fc7 = nn.Linear(in_features=4096, out_features=4096)
        self.dropout7 = nn.Dropout(p=0.5)    #(7)

        self.fc8 = nn.Linear(in_features=4096, out_features=num_classes)

# Codeblock 7b
    def forward(self, x):
        """Run the network on ``x`` and print the tensor size after each stage.

        Args:
            x: Batch of 3-channel images, size ``(N, 3, H, W)``. Because the
                SPP layer pools to fixed grids, the feature vector passed to
                ``fc6`` has the same length for any spatial size that the
                convolution and pooling layers can process.

        Returns:
            Tensor of size ``(N, num_classes)`` holding the raw ``fc8``
            outputs; no softmax is applied.
        """
        print(f'original\t\t: {x.size()}')

        # Stage 1: ReLU is applied before the normalisation.
        x = self.norm1(self.relu(self.conv1(x)))
        print(f'after conv1\t\t: {x.size()}')
        x = self.pool1(x)
        print(f'after pool1\t\t: {x.size()}')

        # Stage 2.
        x = self.norm2(self.relu(self.conv2(x)))
        print(f'after conv2\t\t: {x.size()}')
        x = self.pool2(x)
        print(f'after pool2\t\t: {x.size()}')

        x = self.relu(self.conv3(x))
        print(f'after conv3\t\t: {x.size()}')

        x = self.relu(self.conv4(x))
        print(f'after conv4\t\t: {x.size()}')

        x = self.relu(self.conv5(x))
        print(f'after conv5\t\t: {x.size()}\n')

        # The SPP layer prints its own per-level sizes and returns an
        # already flattened 2-D tensor.
        x = self.spp(x)
        print(f'after spp\t\t: {x.size()}\n')

        x = self.dropout6(self.relu(self.fc6(x)))
        print(f'after fc6\t\t: {x.size()}')

        x = self.dropout7(self.relu(self.fc7(x)))
        print(f'after fc7\t\t: {x.size()}')

        x = self.fc8(x)
        print(f'after fc8\t\t: {x.size()}')

        return x

In [8]:
# Codeblock 8
# Shape check of the full network: only the sizes printed by forward() matter
# here, so a random input is sufficient.
zf5sppnet = ZF5_SPPNet()

# One random 3-channel 224x224 image. No seed is set and the freshly built
# module is never switched to eval(), so the values in `out` differ between runs.
x = torch.randn(1, 3, 224, 224)
out = zf5sppnet(x)

original		: torch.Size([1, 3, 224, 224])
after conv1		: torch.Size([1, 96, 109, 109])
after pool1		: torch.Size([1, 96, 55, 55])
after conv2		: torch.Size([1, 256, 27, 27])
after pool2		: torch.Size([1, 256, 13, 13])
after conv3		: torch.Size([1, 384, 13, 13])
after conv4		: torch.Size([1, 384, 13, 13])
after conv5		: torch.Size([1, 256, 13, 13])

after pool		: torch.Size([1, 256, 4, 4])
after flatten		: torch.Size([1, 4096])

after pool		: torch.Size([1, 256, 2, 2])
after flatten		: torch.Size([1, 1024])

after pool		: torch.Size([1, 256, 1, 1])
after flatten		: torch.Size([1, 256])

after concatenate	: torch.Size([1, 5376])

after spp		: torch.Size([1, 5376])

after fc6		: torch.Size([1, 4096])
after fc7		: torch.Size([1, 4096])
after fc8		: torch.Size([1, 1000])
