In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision import models

In [56]:
search = torch.randn(1, 3, 255, 255)
target = torch.randn(1, 3, 127, 127)

In [57]:
IN_CHANNELS = 3

In [90]:
base_model = models.resnet50(pretrained=True)
base_layers = list(base_model.children())
base_layers[0:4]

[Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False),
 BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
 ReLU(inplace),
 MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)]

In [100]:
len(base_layers)

10

In [59]:
class SearchBone(nn.Module):
    def __init__(self, backbone_channels=1024, out_channels=256):
        super(SearchBone, self).__init__()
        self.backbone_channels = backbone_channels
        self.out_channels = out_channels

        self.backbone = nn.Sequential(*base_layers[0:7])
        self.adjust = nn.Conv2d(backbone_channels, out_channels, kernel_size=1, bias=False)

    def forward(self, x):
        x = self.backbone(x)
        x = self.adjust(x)
        return x

In [98]:
mm = SearchBone()
mm

SearchBone(
  (backbone): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace)
        (downsample): Sequential(
          (0): Conv2d(64, 256, k

In [63]:
x5 = mm(search)
print('x5.shape:', x5.shape)

x5.shape: torch.Size([1, 256, 16, 16])


In [96]:
class TargetBone(nn.Module):
    def __init__(self, backbone_channels=1024, out_channels=256):
        super(TargetBone, self).__init__()
        self.backbone_channels = backbone_channels
        self.out_channels = out_channels

        self.conv1 = nn.Sequential(*base_layers[0:2])
        self.conv2_x = nn.Sequential(*base_layers[2:5])
        self.conv3_x = nn.Sequential(*base_layers[5:6])
        self.conv4_x = nn.Sequential(*base_layers[6:7])
        self.adjust = nn.Conv2d(backbone_channels, out_channels, kernel_size=1, bias=False)

    def forward(self, x):
        print(x.shape)
        conv1 = self.conv1(x)
        print(conv1.shape)
        conv2_x = self.conv2_x(conv1)
        conv3_x = self.conv3_x(conv2_x)
        conv4_x = self.conv4_x(conv3_x)
        adjust = self.adjust(conv4_x)
        return conv1, conv2_x, conv3_x, conv4_x, adjust

In [97]:
mm = TargetBone()
x1, x2, x3, x4, x5 = mm(target)
print('x1.shape:', x1.shape, 'x2.shape:', x2.shape, 'x3.shape:', x3.shape, 'x4.shape:', x4.shape, 'x5.shape:', x5.shape)

torch.Size([1, 3, 127, 127])
torch.Size([1, 64, 64, 64])
x1.shape: torch.Size([1, 64, 64, 64]) x2.shape: torch.Size([1, 256, 32, 32]) x3.shape: torch.Size([1, 512, 16, 16]) x4.shape: torch.Size([1, 1024, 8, 8]) x5.shape: torch.Size([1, 256, 8, 8])


In [39]:
class PoolModule(nn.Module):
    def __init__(self, kernel_size=(17, 17, 256)):
        super(PoolModule, self).__init__()
        self.maxpool3d = nn.MaxPool3d(kernel_size)

    def forward(self, x):
        x = x.unsqueeze(0)
        x = self.maxpool3d(x)
        x = x.squeeze()
        return x

In [43]:
class ModelHead(nn.Module):
    def __init__(self, ):
        super(ModelHead, self).__init__()
        
        self.search_bone = SearchBone()
        self.target_bone = TargetBone()
    
    def Correlation_func(x, kernel):
        x = x.view(1, -1, x.size(2), x.size(3))  # 1 * (b*c) * k * k
        kernel = kernel.view(-1, 1, kernel.size(2), kernel.size(3))  # (b*c) * 1 * H * W
        out = F.conv2d(x, kernel, groups=x.size(1))
        out = out.view(1, x.size(1), out.size(2), out.size(3))
        return out
        

    def forward(self, search, target):
        '''
        TODO: 
            permute for time sequence: (batch, time, channels, input_size, input_size) -->(batch*time, channels, input_size, input_size)
        '''
        search = self.search_bone(search)
        _, _, _, _, target = self.target_bone(target)
        corr_feat = Correlation_func(search, target)
        return corr_feat

In [44]:
model = ModelHead()

In [46]:
model(search, target).shape

torch.Size([1, 256, 9, 9])