# CA

## Overview

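CA (Coordinate Attention) is an attention mechanism proposed in 2021 that attends to both the spatial information and the channel information of a feature map.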

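The authors argue that when existing attention mechanisms compute channel attention, **each channel is usually collapsed with global max pooling or global average pooling, which discards the spatial information of objects**. They therefore want to bring in spatial attention alongside channel attention, and the mechanism they propose does so by embedding positional information into channel attention.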
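
A small PyTorch check of this claim (an illustrative sketch, not code from the paper): two single-channel maps that contain the same "object" in different positions produce identical descriptors after global average pooling and global max pooling, so the pooled value alone cannot tell where the object is.

```python
import torch

# One channel, 4x4 map: an "object" (value 1.0) in the top-left corner
x_a = torch.zeros(1, 1, 4, 4)
x_a[..., 0, 0] = 1.0

# The same object moved to the bottom-right corner
x_b = torch.zeros(1, 1, 4, 4)
x_b[..., 3, 3] = 1.0

# Global average / max pooling collapse H and W into one value per channel
avg_a, avg_b = x_a.mean(dim=(2, 3)), x_b.mean(dim=(2, 3))
max_a, max_b = x_a.amax(dim=(2, 3)), x_b.amax(dim=(2, 3))

print(avg_a, avg_b)  # both equal 0.0625
print(max_a, max_b)  # both equal 1.0
```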

As shown in the figure, the CA implementation can be viewed as starting with `two parallel stages`:

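The input feature map is globally average-pooled separately along the width direction and the height direction, yielding one feature map for each direction. Suppose the input feature layer has shape [C, H, W]. After average pooling along the width direction, the feature layer has `shape [C, H, 1]`, i.e. the features are projected onto the height dimension; after average pooling along the height direction, it has `shape [C, 1, W]`, i.e. the features are projected onto the width dimension.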
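
A minimal PyTorch sketch of the two directional pooling branches, with a batch dimension N added and arbitrary example sizes. Averaging over the width axis keeps H, and averaging over the height axis keeps W; `nn.AdaptiveAvgPool2d((None, 1))` and `nn.AdaptiveAvgPool2d((1, None))` give the same results.

```python
import torch

N, C, H, W = 2, 16, 8, 12  # arbitrary example sizes
x = torch.randn(N, C, H, W)

# Pool along the width direction: one value per (channel, row) -> [N, C, H, 1]
x_h = x.mean(dim=3, keepdim=True)
# Pool along the height direction: one value per (channel, column) -> [N, C, 1, W]
x_w = x.mean(dim=2, keepdim=True)

print(x_h.shape)  # torch.Size([2, 16, 8, 1])
print(x_w.shape)  # torch.Size([2, 16, 1, 12])
```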

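Next, the two parallel stages are merged: **the width and height axes are transposed onto the same dimension and then concatenated**, which combines the height and width features into a feature layer of shape [C, 1, H+W]. A `convolution + normalization + activation function` is then applied to extract features; in the original paper this shared 1×1 convolution also reduces the channel count from C to C/r, where r is a reduction ratio.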
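
The merge step can be sketched in PyTorch as follows. Assumptions: the reduction ratio `r = 4` and the ReLU activation are illustrative choices (the text only says "activation function"), and the height branch is transposed to [N, C, 1, H] so that both branches can be concatenated along the last axis.

```python
import torch
import torch.nn as nn

N, C, H, W = 2, 16, 8, 12
r = 4  # hypothetical reduction ratio
x_h = torch.randn(N, C, H, 1)  # stands in for the width-pooled branch
x_w = torch.randn(N, C, 1, W)  # stands in for the height-pooled branch

# Transpose the height branch to [N, C, 1, H] and concatenate along the last axis
y = torch.cat([x_h.permute(0, 1, 3, 2), x_w], dim=3)  # [N, C, 1, H + W]

# Shared 1x1 conv + BatchNorm + activation; the conv also reduces C to C // r
conv_bn_act = nn.Sequential(
    nn.Conv2d(C, C // r, kernel_size=1, bias=False),
    nn.BatchNorm2d(C // r),
    nn.ReLU(inplace=True),  # assumed activation
)
f = conv_bn_act(y)

print(y.shape)  # torch.Size([2, 16, 1, 20])
print(f.shape)  # torch.Size([2, 4, 1, 20])
```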

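The result is then split back into two parallel stages of shapes [C, 1, H] and [C, 1, W], and the height part is transposed, giving two feature layers of shapes [C, H, 1] and [C, 1, W] (here C denotes the channel count after the shared convolution).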
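
The split and transpose can be sketched with `torch.split` and `permute`. The sizes are example values, and `C_mid` stands for the (possibly reduced) channel count after the shared convolution.

```python
import torch

N, C_mid, H, W = 2, 4, 8, 12
f = torch.randn(N, C_mid, 1, H + W)  # output of the shared conv + BN + activation

# Split the last axis back into the height part and the width part
f_h, f_w = torch.split(f, [H, W], dim=3)  # [N, C_mid, 1, H] and [N, C_mid, 1, W]
# Transpose only the height part back to [N, C_mid, H, 1]
f_h = f_h.permute(0, 1, 3, 2)

print(f_h.shape)  # torch.Size([2, 4, 8, 1])
print(f_w.shape)  # torch.Size([2, 4, 1, 12])
```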

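Finally, a `1x1 convolution adjusts the channel count` of each part back to that of the input, and a sigmoid turns the result into attention weights along the height and width dimensions. Multiplying the original features by these weights gives the output of the CA attention mechanism.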
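
Putting the steps together, a minimal Coordinate Attention block might look like the sketch below. It follows the steps described in this section, but the class name `CoordAttSketch`, the default `reduction=8`, the rule `max(1, channels // reduction)` for the intermediate width, and the ReLU activation are assumptions for illustration, not the authors' exact code.

```python
import torch
import torch.nn as nn


class CoordAttSketch(nn.Module):
    """Hypothetical minimal Coordinate Attention block: [N, C, H, W] -> [N, C, H, W]."""

    def __init__(self, channels, reduction=8):
        super().__init__()
        mid = max(1, channels // reduction)  # reduced channel count (assumed rule)
        # Shared 1x1 conv + BN + activation applied to the concatenated features
        self.shared = nn.Sequential(
            nn.Conv2d(channels, mid, kernel_size=1, bias=False),
            nn.BatchNorm2d(mid),
            nn.ReLU(inplace=True),  # assumed activation
        )
        # Separate 1x1 convs that restore the channel count for each direction
        self.conv_h = nn.Conv2d(mid, channels, kernel_size=1)
        self.conv_w = nn.Conv2d(mid, channels, kernel_size=1)

    def forward(self, x):
        h, w = x.shape[2], x.shape[3]
        x_h = x.mean(dim=3, keepdim=True)                      # [N, C, H, 1]
        x_w = x.mean(dim=2, keepdim=True)                      # [N, C, 1, W]
        y = torch.cat([x_h.permute(0, 1, 3, 2), x_w], dim=3)   # [N, C, 1, H+W]
        y = self.shared(y)                                     # [N, mid, 1, H+W]
        y_h, y_w = torch.split(y, [h, w], dim=3)
        a_h = torch.sigmoid(self.conv_h(y_h.permute(0, 1, 3, 2)))  # [N, C, H, 1]
        a_w = torch.sigmoid(self.conv_w(y_w))                      # [N, C, 1, W]
        # Broadcasting applies a_h across the width and a_w across the height
        return x * a_h * a_w


x = torch.randn(2, 16, 8, 12)
ca = CoordAttSketch(16)
out = ca(x)
print(out.shape)  # torch.Size([2, 16, 8, 12])
```

Because both attention maps lie in (0, 1), the block can only rescale each input value, and every output position is weighted by its own row weight and column weight.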

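For comparison, the channel attention module of CBAM, which relies on the global average and max pooling criticized above, is computed as: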

```{tip}
$$\begin{aligned}\mathbf{M_c(F)}&=\sigma(\mathrm{MLP}(\mathrm{AvgPool}(\mathbf{F}))+\mathrm{MLP}(\mathrm{MaxPool}(\mathbf{F})))\\\mathbf{M_c(F)}&=\sigma(\mathbf{W_1}(\mathbf{W_0}(\mathbf{F}_{avg}^c))+\mathbf{W_1}(\mathbf{W_0}(\mathbf{F}_{max}^c)))\end{aligned}$$
```

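where $\sigma$ is the sigmoid function, $\mathbf{W}_0\in\mathbb{R}^{C/r\times C}$ and $\mathbf{W}_1\in\mathbb{R}^{C\times C/r}$ are the MLP weights shared by both pooled inputs ($r$ is the reduction ratio), and a ReLU is applied after $\mathbf{W}_0$.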
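
The formula translates almost line by line into PyTorch. This sketch assumes C = 16 and r = 4 as example values; the same MLP is applied to both pooled descriptors, matching the two $\mathbf{W}_1(\mathbf{W}_0(\cdot))$ terms, and the `nn.Linear` layers are created without bias only to mirror the formula literally.

```python
import torch
import torch.nn as nn

N, C, H, W = 2, 16, 8, 8
r = 4  # example reduction ratio
F_in = torch.randn(N, C, H, W)

# Shared MLP: W0 maps C -> C/r, ReLU, then W1 maps C/r -> C
mlp = nn.Sequential(
    nn.Linear(C, C // r, bias=False),  # W0
    nn.ReLU(),
    nn.Linear(C // r, C, bias=False),  # W1
)

f_avg = F_in.mean(dim=(2, 3))  # F^c_avg: [N, C]
f_max = F_in.amax(dim=(2, 3))  # F^c_max: [N, C]

# M_c(F) = sigmoid(W1(W0(F_avg)) + W1(W0(F_max)))
M_c = torch.sigmoid(mlp(f_avg) + mlp(f_max))  # [N, C]

# Reshape to [N, C, 1, 1] so the weights broadcast over H and W
out = F_in * M_c[:, :, None, None]
print(M_c.shape, out.shape)  # torch.Size([2, 16]) torch.Size([2, 16, 8, 8])
```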

## Flowchart

```{figure} ../images/attention/channel-attention.png
:width: 400px
:align: center

Channel attention.
```

[See this link for reference](https://paperswithcode.com/method/channel-attention-module)

## Official code (CBAM)

The implementation below is the official CBAM code: `ChannelGate` computes the channel attention formula above, `SpatialGate` adds CBAM's spatial attention, and `CBAM` chains the two.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicConv(nn.Module):
    """Conv2d, optionally followed by BatchNorm and ReLU."""

    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False):
        super().__init__()
        self.out_channels = out_planes
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
        self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None
        self.relu = nn.ReLU() if relu else None

    def forward(self, x):
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu is not None:
            x = self.relu(x)
        return x


class Flatten(nn.Module):
    # [N, C, ...] -> [N, C * ...]
    def forward(self, x):
        return x.view(x.size(0), -1)


class ChannelGate(nn.Module):
    """Channel attention: sigmoid of the summed shared-MLP outputs of each pooled descriptor."""

    def __init__(self, gate_channels, reduction_ratio=16, pool_types=('avg', 'max')):
        super().__init__()
        self.gate_channels = gate_channels
        # Shared MLP: W0 (C -> C/r), ReLU, W1 (C/r -> C)
        self.mlp = nn.Sequential(
            Flatten(),
            nn.Linear(gate_channels, gate_channels // reduction_ratio),
            nn.ReLU(),
            nn.Linear(gate_channels // reduction_ratio, gate_channels),
        )
        self.pool_types = pool_types

    def forward(self, x):
        channel_att_sum = None
        kernel = (x.size(2), x.size(3))  # pool over the whole H x W plane
        for pool_type in self.pool_types:
            if pool_type == 'avg':
                avg_pool = F.avg_pool2d(x, kernel, stride=kernel)
                channel_att_raw = self.mlp(avg_pool)
            elif pool_type == 'max':
                max_pool = F.max_pool2d(x, kernel, stride=kernel)
                channel_att_raw = self.mlp(max_pool)
            elif pool_type == 'lp':
                lp_pool = F.lp_pool2d(x, 2, kernel, stride=kernel)
                channel_att_raw = self.mlp(lp_pool)
            elif pool_type == 'lse':
                # LSE pool only
                lse_pool = logsumexp_2d(x)
                channel_att_raw = self.mlp(lse_pool)
            else:
                raise ValueError(f"unknown pool type: {pool_type}")

            # Sum the MLP outputs of all pooled descriptors
            if channel_att_sum is None:
                channel_att_sum = channel_att_raw
            else:
                channel_att_sum = channel_att_sum + channel_att_raw

        # [N, C] -> [N, C, 1, 1] -> expanded to the shape of x
        scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)
        return x * scale


def logsumexp_2d(tensor):
    # Numerically stable log-sum-exp over H x W: [N, C, H, W] -> [N, C, 1]
    tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1)
    s, _ = torch.max(tensor_flatten, dim=2, keepdim=True)
    outputs = s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log()
    return outputs


class ChannelPool(nn.Module):
    # Channel-wise max and mean: [N, C, H, W] -> [N, 2, H, W]
    def forward(self, x):
        return torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)


class SpatialGate(nn.Module):
    """Spatial attention: 7x7 conv over the channel-pooled map, then sigmoid."""

    def __init__(self):
        super().__init__()
        kernel_size = 7
        self.compress = ChannelPool()
        self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size - 1) // 2, relu=False)

    def forward(self, x):
        x_compress = self.compress(x)
        x_out = self.spatial(x_compress)
        scale = torch.sigmoid(x_out)  # [N, 1, H, W], broadcast over channels
        return x * scale


class CBAM(nn.Module):
    """Channel attention followed by (optional) spatial attention."""

    def __init__(self, gate_channels, reduction_ratio=16, pool_types=('avg', 'max'), no_spatial=False):
        super().__init__()
        self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types)
        self.no_spatial = no_spatial
        if not no_spatial:
            self.SpatialGate = SpatialGate()

    def forward(self, x):
        x_out = self.ChannelGate(x)
        if not self.no_spatial:
            x_out = self.SpatialGate(x_out)
        return x_out
```
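
`SpatialGate` above is CBAM's spatial attention. The following self-contained sketch traces only its shape flow: the channel-wise max and mean computed by `ChannelPool`, the 7×7 convolution with padding 3, and the sigmoid. The BatchNorm inside `BasicConv` is left out for brevity, so this is a simplification rather than a drop-in replacement.

```python
import torch
import torch.nn as nn

N, C, H, W = 2, 16, 8, 8
x = torch.randn(N, C, H, W)

# Channel-wise max and mean, stacked as two channels -> [N, 2, H, W]
pooled = torch.cat([x.amax(dim=1, keepdim=True), x.mean(dim=1, keepdim=True)], dim=1)

# 7x7 conv with padding 3 keeps H and W and yields one attention map -> [N, 1, H, W]
conv = nn.Conv2d(2, 1, kernel_size=7, padding=3, bias=False)
scale = torch.sigmoid(conv(pooled))

# The single-channel map broadcasts over all C channels
out = x * scale
print(pooled.shape, scale.shape, out.shape)
```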