### CBAM网络

In [5]:
import torch
from torch import nn
from torchsummary import summary

#### 通道注意力机制

In [2]:
class ChannelAttention(nn.Module):
    """CBAM channel-attention module.

    Produces a per-channel weight map of shape (N, C, 1, 1) with values in
    (0, 1). Multiply it with the input feature map (broadcasting over H, W)
    to re-weight its channels.

    How it works:
    - The input is reduced to a global average-pooled descriptor and a global
      max-pooled descriptor.
    - Each descriptor passes through the same two-layer bottleneck
      (1x1 convolutions ``fc1`` -> ReLU -> ``fc2``).
    - The two results are summed and squashed with a sigmoid.

    Args:
        in_planes: number of input channels C.
        ratio: channel reduction ratio of the bottleneck. The hidden width is
            ``in_planes // ratio``, clamped to at least 1.
    """

    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        # Without the clamp, in_planes < ratio gives a zero-channel bottleneck:
        # a degenerate layer with no learnable weights.
        hidden_planes = max(in_planes // ratio, 1)

        # fc1/fc2 are shared by the average- and max-pooled branches.
        self.fc1 = nn.Conv2d(in_planes, hidden_planes, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Conv2d(hidden_planes, in_planes, 1, bias=False)

        self.sigmoid = nn.Sigmoid()

    def _shared_mlp(self, pooled):
        # Bottleneck applied identically to both pooled descriptors.
        return self.fc2(self.relu1(self.fc1(pooled)))

    def forward(self, x):
        """Return channel weights of shape (N, C, 1, 1) for x of shape (N, C, H, W)."""
        avg_out = self._shared_mlp(self.avg_pool(x))
        max_out = self._shared_mlp(self.max_pool(x))
        return self.sigmoid(avg_out + max_out)

In [4]:
# Channel-attention module for 128-channel feature maps (default reduction ratio).
ChannelAttention_model = ChannelAttention(in_planes=128)

In [8]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
netG_A2B = ChannelAttention_model.to(device=device)
# Layer-by-layer summary for a 128x24x24 input (torchsummary adds the batch dim).
# NOTE: fc1/fc2 are called once per pooling branch, so torchsummary lists them twice
# and double-counts their weights; the module really holds 128*8 + 8*128 = 2,048 params.
summary(netG_A2B, (128, 24, 24))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
 AdaptiveAvgPool2d-1            [-1, 128, 1, 1]               0
            Conv2d-2              [-1, 8, 1, 1]           1,024
              ReLU-3              [-1, 8, 1, 1]               0
            Conv2d-4            [-1, 128, 1, 1]           1,024
 AdaptiveMaxPool2d-5            [-1, 128, 1, 1]               0
            Conv2d-6              [-1, 8, 1, 1]           1,024
              ReLU-7              [-1, 8, 1, 1]               0
            Conv2d-8            [-1, 128, 1, 1]           1,024
           Sigmoid-9            [-1, 128, 1, 1]               0
Total params: 4,096
Trainable params: 4,096
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.28
Forward/backward pass size (MB): 0.01
Params size (MB): 0.02
Estimated Total Size (MB): 0.30
-----------------------------------------------

#### 空间注意力机制

In [9]:
class SpatialAttention(nn.Module):
    """CBAM spatial-attention module.

    Produces a single-channel weight map of shape (N, 1, H, W) with values in
    (0, 1). Multiply it with the input feature map (broadcasting over C) to
    re-weight spatial positions.

    How it works:
    - The channel-wise mean map and max map are stacked into a 2-channel tensor.
    - That tensor is convolved down to one channel.
    - The result is squashed with a sigmoid.

    The input may have any number of channels.

    Args:
        kernel_size: side length of the convolution kernel (default 7). It must
            be a positive odd integer so that ``padding = kernel_size // 2``
            keeps H and W unchanged (3 -> padding 1, 7 -> padding 3).
    """

    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()

        # "Same" padding for odd kernels keeps the feature-map size unchanged.
        assert isinstance(kernel_size, int) and kernel_size > 0 and kernel_size % 2 == 1, \
            'kernel size must be a positive odd integer'
        padding = kernel_size // 2

        # 2 input channels (mean map, max map) -> single-channel output.
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return spatial weights (N, 1, H, W) for a feature map x of shape (N, C, H, W)."""
        # Pool across channels: mean and max each give one single-channel map.
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv1(x)
        return self.sigmoid(x)

In [10]:
# Spatial-attention module with a 7x7 kernel (the constructor's default).
SpatialAttention_model = SpatialAttention(kernel_size=7)

In [11]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
netG_A2B = SpatialAttention_model.to(device=device)
# The module pools over the channel axis first, so the dummy input may have any
# channel count; a 2-channel 24x24 map is used here.
summary(netG_A2B, (2, 24, 24))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1            [-1, 1, 24, 24]              98
           Sigmoid-2            [-1, 1, 24, 24]               0
Total params: 98
Trainable params: 98
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.01
Params size (MB): 0.00
Estimated Total Size (MB): 0.01
----------------------------------------------------------------


### 在ResNet网络中添加注意力机制

In [12]:
class BasicBlock(nn.Module):
    """ResNet basic residual block with CBAM attention on the residual branch.

    Residual branch: 3x3 conv (stride ``stride``) -> BN -> ReLU -> 3x3 conv -> BN.
    Its output is re-weighted by channel attention, then by spatial attention.
    The identity is passed through ``downsample`` when given, added to the
    branch output, and a final ReLU is applied.

    Args:
        inplanes: number of input channels.
        planes: number of output channels (``expansion`` is 1, so the block's
            output width equals ``planes``).
        stride: stride of the first 3x3 convolution.
        downsample: optional module applied to the block input so that it
            matches the residual branch's shape. Presumably required whenever
            ``stride != 1`` or ``inplanes != planes``; otherwise the addition
            fails on mismatched shapes.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        # 3x3 convolutions with padding 1 and no bias (BatchNorm follows).
        # They are built directly with nn.Conv2d so the cell runs without an
        # external `conv3x3` helper, which is neither defined nor imported here.
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # CBAM attention, applied to the residual branch before the skip addition.
        self.ca = ChannelAttention(planes)
        self.sa = SpatialAttention()

        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Apply the block to x, assumed (N, inplanes, H, W); returns N x planes maps."""
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        # The attention maps are (N, C, 1, 1) and (N, 1, H, W); multiplication
        # broadcasts them over the feature map.
        out = self.ca(out) * out
        out = self.sa(out) * out

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out

![](./attention.jpg)
![](./CBAM.jpg)