In [2]:
import torch
import torch.nn as nn
import math

In [20]:
class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=8):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        # 利用1x1卷积代替全连接
        self.fc1   = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2   = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)

        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        print('x',x.shape)
        print('avg_pool', self.avg_pool(x).shape)
        print('max_pool', self.max_pool(x).shape)
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        out = avg_out + max_out
        print('out', out.shape)
        print('result', self.sigmoid(out).shape)
        return self.sigmoid(out)



In [21]:
x = torch.ones(2,512,28,28)
b,c,w,h = x.shape
channelAttention = ChannelAttention(c)
channelAttention(x)

x torch.Size([2, 512, 28, 28])
avg_pool torch.Size([2, 512, 1, 1])
max_pool torch.Size([2, 512, 1, 1])
out torch.Size([2, 512, 1, 1])
result torch.Size([2, 512, 1, 1])


tensor([[[[0.3882]],

         [[0.3906]],

         [[0.5325]],

         ...,

         [[0.5554]],

         [[0.5454]],

         [[0.4556]]],


        [[[0.3882]],

         [[0.3906]],

         [[0.5325]],

         ...,

         [[0.5554]],

         [[0.5454]],

         [[0.4556]]]], grad_fn=<SigmoidBackward0>)

In [30]:
class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()

        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        print('x',x.shape)
        print('avg_pool', avg_out.shape)
        print('max_pool', avg_out.shape)
        x = torch.cat([avg_out, max_out], dim=1)
        print('cat_x', x.shape)
        x = self.conv1(x)
        print('cat_x', x.shape)
        return self.sigmoid(x)


In [31]:
x = torch.ones(2,512,28,28)
b,c,w,h = x.shape
spatialAttention = SpatialAttention()
spatialAttention(x)

x torch.Size([2, 512, 28, 28])
avg_pool torch.Size([2, 1, 28, 28])
max_pool torch.Size([2, 1, 28, 28])
cat_x torch.Size([2, 2, 28, 28])
cat_x torch.Size([2, 1, 28, 28])


tensor([[[[0.5344, 0.5581, 0.5547,  ..., 0.5775, 0.5661, 0.5444],
          [0.5531, 0.5956, 0.5937,  ..., 0.6075, 0.5751, 0.5679],
          [0.4813, 0.5607, 0.5571,  ..., 0.5881, 0.5915, 0.5816],
          ...,
          [0.4418, 0.5045, 0.5202,  ..., 0.6416, 0.6296, 0.6327],
          [0.4493, 0.5051, 0.5097,  ..., 0.6030, 0.6108, 0.5946],
          [0.4400, 0.4956, 0.5074,  ..., 0.6156, 0.6322, 0.6198]]],


        [[[0.5344, 0.5581, 0.5547,  ..., 0.5775, 0.5661, 0.5444],
          [0.5531, 0.5956, 0.5937,  ..., 0.6075, 0.5751, 0.5679],
          [0.4813, 0.5607, 0.5571,  ..., 0.5881, 0.5915, 0.5816],
          ...,
          [0.4418, 0.5045, 0.5202,  ..., 0.6416, 0.6296, 0.6327],
          [0.4493, 0.5051, 0.5097,  ..., 0.6030, 0.6108, 0.5946],
          [0.4400, 0.4956, 0.5074,  ..., 0.6156, 0.6322, 0.6198]]]],
       grad_fn=<SigmoidBackward0>)

In [51]:
class Cbam_block(nn.Module):
    def __init__(self, channel, ratio=8, kernel_size=7):
        super(Cbam_block, self).__init__()
        self.channelattention = ChannelAttention(channel, ratio=ratio)
        self.spatialattention = SpatialAttention(kernel_size=kernel_size)

    def forward(self, x):
        x = x * self.channelattention(x)
        print('xxx',x.shape)
        x = x * self.spatialattention(x)
        print('xxx2',x.shape)
        return x

In [52]:
x = torch.ones(2,512,28,28)
b,c,w,h = x.shape
cbam_block = Cbam_block(c)
cbam_block(x)

x torch.Size([2, 512, 28, 28])
avg_pool torch.Size([2, 512, 1, 1])
max_pool torch.Size([2, 512, 1, 1])
out torch.Size([2, 512, 1, 1])
result torch.Size([2, 512, 1, 1])
xxx torch.Size([2, 512, 28, 28])
x torch.Size([2, 512, 28, 28])
avg_pool torch.Size([2, 1, 28, 28])
max_pool torch.Size([2, 1, 28, 28])
cat_x torch.Size([2, 2, 28, 28])
cat_x torch.Size([2, 1, 28, 28])
xxx2 torch.Size([2, 512, 28, 28])


tensor([[[[0.2486, 0.2509, 0.2418,  ..., 0.2716, 0.2586, 0.2966],
          [0.2231, 0.2237, 0.2301,  ..., 0.2742, 0.2688, 0.3155],
          [0.2366, 0.2315, 0.2503,  ..., 0.2789, 0.2646, 0.3042],
          ...,
          [0.2475, 0.2372, 0.2675,  ..., 0.2878, 0.2696, 0.3045],
          [0.2793, 0.2612, 0.2954,  ..., 0.3058, 0.2786, 0.3024],
          [0.2898, 0.2695, 0.3027,  ..., 0.3031, 0.2861, 0.3004]],

         [[0.2135, 0.2154, 0.2076,  ..., 0.2332, 0.2220, 0.2547],
          [0.1915, 0.1921, 0.1975,  ..., 0.2354, 0.2308, 0.2709],
          [0.2032, 0.1987, 0.2150,  ..., 0.2394, 0.2272, 0.2612],
          ...,
          [0.2126, 0.2037, 0.2297,  ..., 0.2471, 0.2315, 0.2615],
          [0.2398, 0.2243, 0.2537,  ..., 0.2626, 0.2392, 0.2596],
          [0.2488, 0.2314, 0.2599,  ..., 0.2602, 0.2457, 0.2579]],

         [[0.2407, 0.2429, 0.2341,  ..., 0.2629, 0.2503, 0.2872],
          [0.2160, 0.2166, 0.2227,  ..., 0.2654, 0.2603, 0.3054],
          [0.2291, 0.2241, 0.2424,  ..., 0

In [47]:
a = torch.Tensor([[[[1,2,3],[1,2,3],[1,2,3]]]])
print(a.shape)
b = torch.Tensor([[[[2,4,2]]]])
print(b.shape)
a*b

torch.Size([1, 1, 3, 3])
torch.Size([1, 1, 1, 3])


tensor([[[[2., 8., 6.],
          [2., 8., 6.],
          [2., 8., 6.]]]])