In [1]:
# Original Source https://www.youtube.com/watch?v=vRYM1KdFtnk
# Arxiv link: https://arxiv.org/abs/1807.06521

In [2]:
import torch 
import torch.nn as nn

In [3]:
class channel_attention_module(nn.Module):
    """Channel attention: re-weights each channel of a feature map.

    The input is reduced to one value per channel twice, once by global
    average pooling and once by global max pooling. Both descriptors pass
    through the same bias-free bottleneck MLP, are summed, and go through a
    sigmoid to produce one gate per channel. The gates are broadcast over the
    spatial dimensions and multiplied with the input.

    Args:
        channels: Number of channels of the input feature map (the size of
            the dimension that precedes the two spatial dimensions).
        ratio: Reduction factor of the MLP bottleneck. The hidden width is
            ``channels // ratio``, clamped to a minimum of 1.
    """

    def __init__(self, channels, ratio = 8) -> None:
        super().__init__()

        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        # Clamp to 1: with channels < ratio the floor division yields 0, and a
        # zero-width, bias-free MLP always outputs zeros, which would pin
        # every gate at sigmoid(0) = 0.5 regardless of the input.
        hidden = max(1, channels // ratio)

        # One MLP shared by the average- and max-pooled descriptors.
        self.mlp = nn.Sequential(nn.Linear(channels, hidden, bias=False),
                                 nn.ReLU(inplace = True),
                                 nn.Linear(hidden, channels, bias = False)
                                 )

        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` scaled channel-wise by the attention gates.

        Args:
            x: Feature map whose last two dimensions are spatial and whose
                preceding dimension has ``channels`` entries,
                e.g. ``(N, C, H, W)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Pooling leaves trailing 1x1 spatial dims; drop them so the Linear
        # layers act on the channel dimension.
        avg_desc = self.mlp(self.avg_pool(x).squeeze(-1).squeeze(-1))
        max_desc = self.mlp(self.max_pool(x).squeeze(-1).squeeze(-1))

        # Restore the 1x1 spatial dims so the gates broadcast over H and W.
        gates = self.sigmoid(avg_desc + max_desc).unsqueeze(-1).unsqueeze(-1)

        return x * gates

In [22]:
# My adaptation to align with the original implementation
class spatial_attention_module(nn.Module):
    """Spatial attention: re-weights each spatial location of a feature map.

    A single bias-free convolution maps all ``channels`` input channels to a
    one-channel map, a sigmoid turns that map into gates in (0, 1), and the
    gates are broadcast over the channel dimension and multiplied with the
    input.

    Note: this variant convolves the full feature map directly. It does not
    first reduce the channel axis with mean/max pooling and concatenate the
    two maps, so the convolution takes ``channels`` inputs rather than 2.

    Args:
        channels: Number of channels of the input feature map.
        kernel_size: Side length of the square convolution kernel. Use an odd
            value: the padding is ``kernel_size // 2``, which keeps the
            spatial size unchanged only for odd kernels. An even value yields
            a map one element larger along each spatial dimension, which
            generally does not match the input in ``forward``.
    """

    def __init__(self, channels,  kernel_size = 7) -> None:
        super().__init__()

        # Derive the padding from the kernel size so the attention map always
        # has the input's spatial size; a fixed padding of 3 is only correct
        # for kernel_size == 7.
        self.conv = nn.Conv2d(channels, 1, kernel_size,
                              padding=kernel_size // 2, bias = False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` scaled location-wise by the attention gates.

        Args:
            x: Feature map with ``channels`` channels, e.g. ``(N, C, H, W)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        gates = self.sigmoid(self.conv(x))
        # gates has a single channel and broadcasts across all channels of x.
        return x * gates

In [23]:
class cbam(nn.Module):
    """Chains channel attention and spatial attention over a feature map.

    Args:
        channels: Number of channels of the input feature map; passed to
            both attention sub-modules.
    """

    def __init__(self, channels):
        super().__init__()

        # Built in the order they are applied: channel gate, then spatial gate.
        self.ca = channel_attention_module(channels)
        self.sa = spatial_attention_module(channels)

    def forward(self, x):
        """Apply the channel attention module, then the spatial one, to ``x``."""
        channel_refined = self.ca(x)
        return self.sa(channel_refined)
        

In [24]:
if __name__ == "__main__":
    # Smoke test: CBAM must return a tensor with the same shape as its input.
    x = torch.randn((8, 32, 512, 512))
    channels = x.shape[1]
    module = cbam(channels)

    # Forward pass only: disable autograd so the intermediates for this large
    # input are not kept alive for a backward pass that never happens.
    with torch.no_grad():
        y = module(x)

    assert y.shape == x.shape, f"expected {tuple(x.shape)}, got {tuple(y.shape)}"
    print(y.shape)

torch.Size([8, 32, 512, 512])
torch.Size([8, 32, 512, 512])
