In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import torch.utils.model_zoo as model_zoo


## DETR 编码环节

In [None]:


class PositionEmbeddingSine(nn.Module):
    """Sinusoidal 2-D position embedding for image feature maps.

    A version of the sine/cosine position encoding from "Attention Is All You
    Need", generalised to images. Each pixel's row and column index is counted
    over valid (non-padded) pixels and encoded separately. The two encodings
    are then concatenated along the channel axis.

    Args:
        num_pos_feats: Number of encoding channels per spatial axis. Must be
            even, because sin and cos halves are interleaved pairwise;
            otherwise ``torch.stack`` in ``forward`` fails. The output has
            ``2 * num_pos_feats`` channels.
        temperature: Base of the geometric progression of frequencies.
        normalize: If True, each cumulative position is divided by the total
            number of valid pixels in its column (for y) or row (for x) and
            multiplied by ``scale``, so positions lie in roughly ``(0, scale]``.
        scale: Multiplier used when ``normalize`` is True. Defaults to ``2*pi``.

    Raises:
        ValueError: If ``scale`` is given while ``normalize`` is False.
    """

    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    # The annotation is quoted because `NestedTensor` is not defined or
    # imported in this module. An unquoted annotation is evaluated when the
    # class body executes, which raises NameError.
    def forward(self, tensor_list: "NestedTensor"):
        """Compute position encodings for a padded image batch.

        Args:
            tensor_list: Object exposing ``.tensors`` (used here only for its
                device) and ``.mask``, a boolean tensor in which True marks
                padded pixels. The mask is indexed as (batch, H, W) below.

        Returns:
            Tensor of shape (batch, 2 * num_pos_feats, H, W). The y-axis
            encoding occupies the first ``num_pos_feats`` channels and the
            x-axis encoding the rest.
        """
        x = tensor_list.tensors
        mask = tensor_list.mask
        assert mask is not None
        not_mask = ~mask  # True where the pixel is valid (not padding)
        # Cumulative count of valid pixels along height (dim 1) and width
        # (dim 2). This gives a 1-based position index that skips padding.
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        x_embed = not_mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            # Rescale so the last valid position along each axis maps to about
            # `scale`. eps avoids 0/0 for fully padded rows/columns.
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        # Frequencies are temperature ** (2 * floor(i / 2) / num_pos_feats),
        # so each consecutive pair of channels shares one frequency.
        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        # Apply sin to even channels and cos to odd channels, then interleave
        # them back: [sin(f0), cos(f0), sin(f1), cos(f1), ...].
        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
        # (B, H, W, 2*num_pos_feats) -> (B, 2*num_pos_feats, H, W), y channels first.
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos


In [None]:
# CBAM的通道注意力和空间注意力模块
class ChannelAttention(nn.Module):
    """CBAM channel-attention module.

    Squeezes the spatial dimensions with both average and max pooling and
    passes each pooled descriptor through a shared two-layer 1x1-conv
    bottleneck. The two results are summed and passed through a sigmoid,
    giving one weight in (0, 1) per channel.

    Args:
        in_planes: Number of input channels.
        ratio: Channel reduction factor of the bottleneck. The hidden width is
            ``in_planes // ratio``, clamped to at least 1.
    """

    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        # Honour `ratio` (it was previously ignored in favour of a hard-coded
        # 16). Clamp to >= 1 so in_planes < ratio still yields a real
        # bottleneck instead of a zero-width layer.
        hidden = max(1, in_planes // ratio)
        self.fc = nn.Sequential(
            nn.Conv2d(in_planes, hidden, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(hidden, in_planes, 1, bias=False),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return per-channel attention weights.

        For a 4-D input (N, C, H, W) the result has shape (N, C, 1, 1) and is
        meant to be multiplied with ``x`` by the caller.
        """
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        out = avg_out + max_out
        return self.sigmoid(out)


class SpatialAttention(nn.Module):
    """CBAM spatial-attention module.

    Pools across channels with mean and max and stacks the two single-channel
    maps. It then convolves them with one ``kernel_size`` x ``kernel_size``
    filter and applies a sigmoid, giving one weight per spatial location.

    Args:
        kernel_size: Side length of the convolution kernel. Must be a positive
            odd integer, so that the ``kernel_size // 2`` padding preserves the
            input's spatial size.

    Raises:
        ValueError: If ``kernel_size`` is not a positive odd integer.
    """

    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        # An even kernel with padding k // 2 yields an output one pixel larger
        # than the input. That would only surface later as a broadcast error in
        # `sa(x) * x`, so reject it up front.
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(
                f"kernel_size must be a positive odd integer, got {kernel_size}"
            )
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return a (N, 1, H, W) attention map for a (N, C, H, W) input."""
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv1(x)
        return self.sigmoid(x)

class ResNetWithLayerCBAM(nn.Module):
    """ResNet，将CBAM放在每个Layer之后"""

    def __init__(self, block, layers, num_classes=1000):
        """Build the ResNet trunk with a CBAM pair after each of the four stages.

        Args:
            block: Residual block class forwarded to ``self._make_layer``. Its
                ``expansion`` attribute is the ratio between a stage's nominal
                width (64/128/256/512) and its output channel count.
            layers: Number of blocks in each of the four stages (``layers[0]``
                through ``layers[3]``).
            num_classes: Output size of the final linear classifier.
        """
        super().__init__()

        # NOTE(review): `_make_layer` is not defined in the class body shown
        # here. Presumably the standard ResNet helper, which also advances
        # `self.inplanes`, is provided elsewhere; confirm.
        self.inplanes = 64

        # Stem
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Each stage outputs `planes * block.expansion` channels; the
        # classifier below relies on the same relation for layer4. Size the
        # CBAM modules from it instead of hard-coding 256/512/1024/2048, which
        # only matched blocks with expansion == 4.
        expansion = block.expansion

        # Four stages, each followed by channel attention + spatial attention.
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.ca1 = ChannelAttention(64 * expansion)   # layer1 output channels
        self.sa1 = SpatialAttention()

        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.ca2 = ChannelAttention(128 * expansion)  # layer2 output channels
        self.sa2 = SpatialAttention()

        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.ca3 = ChannelAttention(256 * expansion)  # layer3 output channels
        self.sa3 = SpatialAttention()

        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.ca4 = ChannelAttention(512 * expansion)  # layer4 output channels
        self.sa4 = SpatialAttention()

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * expansion, num_classes)

    def forward(self, x):
        # 初始层
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        # Layer1 + CBAM
        x = self.layer1(x)
        x = self.ca1(x) * x
        x = self.sa1(x) * x

        # Layer2 + CBAM
        x = self.layer2(x)
        x = self.ca2(x) * x
        x = self.sa2(x) * x

        # Layer3 + CBAM
        x = self.layer3(x)
        x = self.ca3(x) * x
        x = self.sa3(x) * x

        # Layer4 + CBAM
        x = self.layer4(x)
        x = self.ca4(x) * x
        x = self.sa4(x) * x

        # 全局池化和分类
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x