In [1]:
import torch
import torch.nn as nn
import numpy as np

In [2]:
class LocatlizationNetwork(nn.Module):
    # 定位出K個基準點
    # 一個點由(x,y)定義，所以會有2K個輸出
    
    def __init__(self, K, input_channel):
        super(LocatlizationNetwork, self).__init__()
        self.K = K
        self.input_channel = input_channel
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=self.input_channel, out_channels=64, 
                      kernel_size=(3,3), stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.MaxPool2d(2,2),
            
            nn.Conv2d(64, 128, (3,3), 1, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.MaxPool2d(2,2),
            
            nn.Conv2d(128, 256, (3,3), 1, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.MaxPool2d(2,2),
            
            nn.Conv2d(256, 512, (3,3), 1, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            nn.AdaptiveAvgPool2d(1)
            )
        
        self.location_fc1 = nn.Sequential(
            nn.Linear(512, 1024),
            nn.ReLU(True)
        )
        self.location_fc2 = nn.Linear(1024, 2*K)
        self.tanh = nn.Tanh() # 範圍[-1, 1]
        
        # initialized fc2 weight 
        self.location_fc2.weight.data.fill_(0)
        
        # initialized fc2 weight and bias
        pt_x = np.linspace(-1.0, 1.0, int(K/2)) # 產生基準點的x座標
        pt_y_top = np.linspace(0.0, -1.0, int(K/2)) # 產生上半部基準點的y座標
        pt_y_bottom = np.linspace(1, 0.0, int(K/2)) # 產生下半部基準點的y座標
        pt_top = np.stack([pt_x, pt_y_top], axis=1) # 上半部基準點
        pt_bottom = np.stack([pt_x, pt_y_bottom], axis=1) # 下半部基準點
        init_bias = np.concatenate([pt_top, pt_bottom], axis=0)
        self.location_fc2.bias.data = torch.from_numpy(init_bias).float().view(-1)
        
    def forward(self, x):
        out = self.conv(x)
        out = out.view(out.size(0), -1)
        out = self.location_fc1(out)
        out = self.location_fc2(out)
        out = self.tanh(out)
        
        return out

In [3]:
model = LocatlizationNetwork(20,3)

In [4]:
from torchsummary import summary

summary(model, (3, 100, 32))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 64, 100, 32]           1,728
       BatchNorm2d-2          [-1, 64, 100, 32]             128
              ReLU-3          [-1, 64, 100, 32]               0
         MaxPool2d-4           [-1, 64, 50, 16]               0
            Conv2d-5          [-1, 128, 50, 16]          73,728
       BatchNorm2d-6          [-1, 128, 50, 16]             256
              ReLU-7          [-1, 128, 50, 16]               0
         MaxPool2d-8           [-1, 128, 25, 8]               0
            Conv2d-9           [-1, 256, 25, 8]         294,912
      BatchNorm2d-10           [-1, 256, 25, 8]             512
             ReLU-11           [-1, 256, 25, 8]               0
        MaxPool2d-12           [-1, 256, 12, 4]               0
           Conv2d-13           [-1, 512, 12, 4]       1,179,648
      BatchNorm2d-14           [-1, 512

In [5]:
# 測試一個batch中有兩張100*32的影像輸入到localization network後的結果
input = torch.randn(2, 3, 100, 32)

output = model(input)

In [6]:
output[0]

tensor([-0.7616,  0.0000, -0.6514, -0.1107, -0.5047, -0.2186, -0.3215, -0.3215,
        -0.1107, -0.4173,  0.1107, -0.5047,  0.3215, -0.5828,  0.5047, -0.6514,
         0.6514, -0.7108,  0.7616, -0.7616, -0.7616,  0.7616, -0.6514,  0.7108,
        -0.5047,  0.6514, -0.3215,  0.5828, -0.1107,  0.5047,  0.1107,  0.4173,
         0.3215,  0.3215,  0.5047,  0.2186,  0.6514,  0.1107,  0.7616,  0.0000],
       grad_fn=<SelectBackward>)