# 基于细粒度特征的热红外目标追踪 Part 3

## 编码实现细粒度特征提取网络结构

这一部分主要通过构建深层神经网络，从热红外图像中提取上一部分所定义的细粒度特征，并将细粒度特征提取网络整合为独立的框架，方便在下一步骤中进行嵌入。

*Reference:*

Liu Q, Li X, He Z, et al. Multi-Task Driven Feature Models for Thermal Infrared Tracking[C]//2020 Thirty-Fourth AAAI Conference on Artificial Intelligence. AAAI, 2020.

Vaswani A, Shazeer N, Parmar N, et al. Attention is all you need[C]//Advances in neural information processing systems. 2017: 5998-6008.

Wang X, Girshick R, Gupta A, et al. Non-local neural networks[C]//Proceedings of the IEEE conference on computer vision and pattern recognition. 2018: 7794-7803.

Cao Y, Xu J, Lin S, et al. Gcnet: Non-local networks meet squeeze-excitation networks and beyond[C]//Proceedings of the IEEE International Conference on Computer Vision Workshops. 2019: 0-0.

回顾上节，本文提出的细粒度特征提取网络如下所示：

![encoderDecoder.png](./images4paper/encoderDecoder.png)

图中，$\oplus$表示逐元素相加操作。其中自注意力模块的结构如下图所示：

![attentionBlockV2.png](./images4paper/attentionBlockV2.png)

图中，$\otimes$表示批量（batch）矩阵乘法，$\oplus$表示逐元素相加操作（必要时按空间维度广播）。

基于PyTorch，我们实现上文所提出的细粒度特征提取网络：

In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [2]:
class AttentionBlock(nn.Module):
    """Global-context attention block with a residual connection.

    A 1x1 convolution scores every spatial position, the scores are
    softmax-normalized over all ``H * W`` positions, and the input is pooled
    with those weights into a single ``[N, C, 1, 1]`` context vector. The
    vector is passed through a transform (1x1 conv -> LayerNorm -> ReLU ->
    1x1 conv) and added back to every position of the input by broadcasting.

    The last convolution of the transform is zero-initialized, so a freshly
    constructed block returns its input unchanged.

    Args:
        in_channels: Number of channels ``C`` of the input feature map. The
            output has the same number of channels.
    """

    def __init__(self, in_channels):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = in_channels

        # Produces one attention logit per spatial position.
        self.conv = nn.Conv2d(self.in_channels, 1, kernel_size=1)
        # Normalizes the logits over the flattened spatial axis of [N, 1, H * W].
        self.softmax = nn.Softmax(dim=2)

        self.transformer = nn.Sequential(
            nn.Conv2d(self.in_channels, self.out_channels, kernel_size=1),
            nn.LayerNorm([self.out_channels, 1, 1]),
            nn.ReLU(inplace=True),
            nn.Conv2d(self.out_channels, self.in_channels, kernel_size=1),
        )
        # Zero-init the last layer so the residual branch starts as a no-op.
        nn.init.constant_(self.transformer[-1].weight, 0)
        nn.init.constant_(self.transformer[-1].bias, 0)

    def feature_extractor(self, x: torch.Tensor) -> torch.Tensor:
        """Pool ``x`` over its spatial positions with learned attention weights.

        Args:
            x: Feature map of shape ``[N, C, H, W]``. It does not have to be
                contiguous in memory.

        Returns:
            Attention-weighted sum over all positions, shape ``[N, C, 1, 1]``.
        """
        batch, channel, height, width = x.size()
        positions = height * width

        # `reshape` instead of `view`: `view` raises on non-contiguous input
        # (e.g. a transposed or sliced feature map); `reshape` copies if needed.
        # [N, C, H * W]
        flat = x.reshape(batch, channel, positions)

        # [N, 1, H, W] -> [N, 1, H * W], normalized over the positions.
        weights = self.softmax(self.conv(x).reshape(batch, 1, positions))

        # [N, C, H * W] @ [N, H * W, 1] -> [N, C, 1]
        feature = torch.matmul(flat, weights.transpose(1, 2))

        # [N, C, 1, 1]
        return feature.reshape(batch, channel, 1, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add the transformed global context to every position of ``x``.

        Args:
            x: Feature map of shape ``[N, C, H, W]`` with ``C == in_channels``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # [N, C, 1, 1]
        feature = self.feature_extractor(x)
        fine_grained = self.transformer(feature)
        # Broadcasts the [N, C, 1, 1] context over the H x W grid.
        return x + fine_grained

In [3]:
class PixelNet(nn.Module):
    """Convolutional encoder followed by two parallel attention blocks.

    The encoder is three stages of 5x5 convolution (stride 1, padding 2),
    batch normalization, ReLU and 2x2 max pooling, taking a single-channel
    image to a 32-channel feature map at 1/8 of the input resolution
    (rounded down at each pooling step). Two ``AttentionBlock`` instances are
    applied to that feature map in parallel, their outputs are concatenated
    along the channel axis, projected back to 32 channels by ``W`` and added
    to the encoder feature as a residual.

    The batch normalization inside ``W`` is zero-initialized, so the residual
    branch contributes nothing at initialization.
    """

    def __init__(self):
        super().__init__()

        # 1 -> 32 channels, spatial size halved.
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

        # 32 -> 64 channels, spatial size halved.
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

        # 64 -> 32 channels, spatial size halved.
        self.conv3 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=32, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

        self.atblock_1 = AttentionBlock(in_channels=32)
        self.atblock_2 = AttentionBlock(in_channels=32)

        # Fuses the two concatenated 32-channel attention outputs (64 channels).
        # The projection must return to 32 channels: its output is added to the
        # 32-channel encoder feature in `forward`, and any other channel count
        # (other than 1) cannot be broadcast against it.
        self.W = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=32, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(32),
        )
        # Zero-init the normalization so the residual branch starts as a no-op.
        nn.init.constant_(self.W[1].weight, 0)
        nn.init.constant_(self.W[1].bias, 0)

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """Extract the feature map for a batch of single-channel images.

        Args:
            t: Input of shape ``[N, 1, H, W]``.

        Returns:
            Feature map of shape ``[N, 32, H // 8, W // 8]``.
        """
        t = self.conv1(t)
        t = self.conv2(t)
        t = self.conv3(t)

        # Both attention blocks see the same encoder feature.
        t_1 = self.atblock_1(t)
        t_2 = self.atblock_2(t)
        # [N, 64, H // 8, W // 8]
        t_ = torch.cat((t_1, t_2), 1)

        # Residual fusion; shapes match because W outputs 32 channels.
        return t + self.W(t_)