In [1]:
import torch
import torch.nn as nn

In [2]:
class ChannelAttention(nn.Module):
    """Channel attention module for CBAM.

    Squeezes the spatial dimensions with both average- and max-pooling,
    passes each pooled descriptor through the same bottleneck MLP
    (implemented with 1x1 convolutions), sums the two results and applies
    a sigmoid to obtain one weight in [0, 1] per channel.

    Args:
        in_channels: Number of channels of the input feature map.
        reduction: Bottleneck reduction ratio. The hidden width is
            ``in_channels // reduction``, clamped to at least 1 so that
            inputs with fewer channels than ``reduction`` still work.

    Raises:
        ValueError: If ``in_channels`` or ``reduction`` is not positive.
    """
    def __init__(self, in_channels, reduction=16):
        super().__init__()
        if in_channels < 1 or reduction < 1:
            raise ValueError(
                "in_channels and reduction must be positive, got "
                f"in_channels={in_channels}, reduction={reduction}"
            )

        # Both poolings return one value per channel: [N, C, 1, 1].
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        # Never let the bottleneck collapse to zero channels.
        hidden_channels = max(in_channels // reduction, 1)

        # Kernel size 1 mixes channel information only, never spatial.
        self.fc = nn.Sequential(
            nn.Conv2d(in_channels, hidden_channels, 1, bias=False),  # [N, C/r, 1, 1]
            nn.ReLU(inplace=True),  # Bottleneck non-linearity
            nn.Conv2d(hidden_channels, in_channels, 1, bias=False),  # [N, C, 1, 1]
        )
        self.sigmoid = nn.Sigmoid()  # Scale to 0-1 per channel

    def forward(self, x):
        """Return the channel attention map for ``x``.

        Args:
            x: Feature map, expected to be of shape [N, C, H, W].

        Returns:
            Tensor of shape [N, C, 1, 1] with values in [0, 1].
        """
        # The same MLP (shared weights) processes both pooled descriptors.
        max_out = self.fc(self.max_pool(x))
        avg_out = self.fc(self.avg_pool(x))
        return self.sigmoid(max_out + avg_out)

![image.png](attachment:07cf689b-b97f-47f8-8924-a9a9f47ff931.png)

In [3]:
class SpatialAttention(nn.Module):
    """Spatial attention module for CBAM.

    Pools the input along the channel axis with both mean and max,
    stacks the two resulting maps, and runs a single convolution followed
    by a sigmoid to obtain one weight in [0, 1] per spatial location.

    Args:
        kernel_size: Side length of the square convolution kernel. Must be
            a positive odd integer so that "same" padding keeps the
            attention map the same height and width as the input.

    Raises:
        ValueError: If ``kernel_size`` is even or not positive.
    """
    def __init__(self, kernel_size=7):
        super().__init__()
        # (k - 1) // 2 only preserves H x W for odd k; an even kernel would
        # shrink the map by one pixel and break the element-wise product
        # with the input, so fail early with a clear message.
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(
                f"kernel_size must be a positive odd integer, got {kernel_size}"
            )
        padding = (kernel_size - 1) // 2
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)  # [N, 1, H, W]
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return the spatial attention map for ``x``.

        Args:
            x: Feature map, expected to be of shape [N, C, H, W]
                (channels on dim 1).

        Returns:
            Tensor of shape [N, 1, H, W] with values in [0, 1].
        """
        max_out, _ = torch.max(x, dim=1, keepdim=True)  # [N, 1, H, W]
        avg_out = torch.mean(x, dim=1, keepdim=True)  # [N, 1, H, W]
        # Channel order matters for the conv weights: average first, then max.
        out = torch.cat([avg_out, max_out], dim=1)  # [N, 2, H, W]
        out = self.conv(out)
        return self.sigmoid(out)

![image.png](attachment:4069fa1e-892b-49ca-91c0-e5bb5cef243c.png)  
Here $M_c$ and $M_s$ are the channel and spatial attention maps, respectively.
![image.png](attachment:7fd56f7c-9cda-4141-861b-8be28cbb9150.png)

In [4]:
class CBAM(nn.Module):
    """Convolutional Block Attention Module.

    Refines a feature map by multiplying it first with a channel
    attention map and then with a spatial attention map.

    Args:
        in_channels: Channel count, forwarded to ``ChannelAttention``.
        reduction: Bottleneck ratio, forwarded to ``ChannelAttention``.
        kernel_size: Convolution kernel size, forwarded to
            ``SpatialAttention``.
    """
    def __init__(self, in_channels, reduction=16, kernel_size=7):
        super().__init__()
        self.channel_attention = ChannelAttention(in_channels, reduction)
        self.spatial_attention = SpatialAttention(kernel_size)

    def forward(self, x):
        """Apply channel attention, then spatial attention, to ``x``."""
        # The two stages are sequential: the spatial map is computed from
        # the already channel-refined features, not from the raw input.
        channel_refined = x * self.channel_attention(x)
        spatial_weights = self.spatial_attention(channel_refined)
        return channel_refined * spatial_weights