<a href="https://colab.research.google.com/github/BhardwajArjit/Research-Paper-Replication/blob/main/CBAM_Replication.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## This notebook replicates the research paper titled "**CBAM: Convolutional Block Attention Module**" with PyTorch.

Link to the paper: https://arxiv.org/abs/1807.06521

CBAM (Convolutional Block Attention Module) aims to enhance the feature representation of convolutional neural networks by incorporating channel-wise and spatial-wise attention mechanisms.

The channel module focuses on "what" is meaningful in the given input image, whereas the spatial module focuses on "where" the meaningful features are in the image.

## 0. Setup

In [1]:
import torch
from torch import nn

## 1. Channel Attention Module

In [2]:
class channel_attention_module(nn.Module):
  """Channel attention branch of CBAM: decides "what" to emphasise.

  The input is pooled over its spatial dimensions twice (global average and
  global max). Both pooled descriptors go through the same bias-free
  bottleneck MLP, the two outputs are summed, passed through a sigmoid and
  used to rescale each channel of the input.

  Args:
    channel: number of channels of the incoming feature map.
    ratio: reduction ratio of the MLP bottleneck. The hidden width is
      ``channel // ratio``, clamped to at least 1.
  """

  def __init__(self, channel, ratio=8):
    super().__init__()

    # Without the clamp, channel < ratio gives a zero-width hidden layer: the
    # bias-free MLP then outputs all zeros and every gate collapses to
    # sigmoid(0) = 0.5, independent of the input.
    hidden = max(1, channel // ratio)

    # defining average pooling layer
    self.avg_pool = nn.AdaptiveAvgPool2d(1)
    # defining max pooling layer
    self.max_pool = nn.AdaptiveMaxPool2d(1)
    # defining multi-layer perceptron (shared by both pooled descriptors)
    self.mlp = nn.Sequential(
        nn.Linear(channel, hidden, bias=False),
        nn.ReLU(inplace=True),
        nn.Linear(hidden, channel, bias=False)
    )
    # defining the sigmoid function
    self.sigmoid = nn.Sigmoid()

  def forward(self, x):
    """Rescale the channels of ``x`` by their attention weights.

    Args:
      x: feature map of shape (N, C, H, W) with C equal to ``channel``.

    Returns:
      Tensor with the same shape as ``x``.
    """
    # (N, C, 1, 1) -> (N, C) so the pooled descriptors can feed nn.Linear
    x1 = self.avg_pool(x).squeeze(-1).squeeze(-1)
    x1 = self.mlp(x1)
    x2 = self.max_pool(x).squeeze(-1).squeeze(-1)
    x2 = self.mlp(x2)

    feats = x1 + x2
    # back to (N, C, 1, 1) so the gate broadcasts over H and W
    feats = self.sigmoid(feats).unsqueeze(-1).unsqueeze(-1)
    # multiplying the output of channel attention module with input features
    refined_features = x * feats

    return refined_features

In [3]:
# Checking the results of channel attention module
x = torch.randn((8, 32, 128, 128))
module_1 = channel_attention_module(32)
y = module_1(x)
# attention only rescales the features, so the output must keep the input shape
assert y.shape == x.shape, f"expected {tuple(x.shape)}, got {tuple(y.shape)}"
print(y.shape)

torch.Size([8, 32, 128, 128])


## 2. Spatial Attention Module

In [4]:
class spatial_attention_module(nn.Module):
  """Spatial attention branch of CBAM: decides "where" to emphasise.

  The input is reduced along the channel axis with both a mean and a max,
  the two single-channel maps are concatenated and convolved down to one
  channel, and the sigmoid of that map rescales every spatial location of
  the input.

  Args:
    kernel_size: kernel size of the convolution that produces the attention
      map (7 by default).
  """

  def __init__(self, kernel_size=7):
    super().__init__()

    # defining the convolutional layer
    # "same" padding keeps the attention map at the input's H x W for any
    # kernel size; a fixed padding of 3 is only correct when kernel_size == 7.
    self.conv = nn.Conv2d(2, 1, kernel_size, padding="same", bias=False)
    # defining the sigmoid function
    self.sigmoid = nn.Sigmoid()

  def forward(self, x):
    """Rescale the spatial locations of ``x`` by their attention weights.

    Args:
      x: feature map of shape (N, C, H, W).

    Returns:
      Tensor with the same shape as ``x``.
    """
    # channel-wise statistics, each of shape (N, 1, H, W)
    x1 = torch.mean(x, dim=1, keepdim=True)
    x2, _ = torch.max(x, dim=1, keepdim=True)
    # concatenate x1 and x2 into a 2-channel feature descriptor
    feats = torch.cat([x1, x2], dim=1)
    # passing the features to compute 2d spatial attention map
    feats = self.conv(feats)
    feats = self.sigmoid(feats)

    # (N, 1, H, W) gate broadcasts over the channel axis
    refined_features = x * feats
    return refined_features

In [5]:
# checking the results of spatial attention module
x = torch.randn((8, 32, 128, 128))
module_2 = spatial_attention_module()
y = module_2(x)
print(y.shape)

torch.Size([8, 32, 128, 128])


## 3. CBAM (Convolutional Block Attention Module)

In [6]:
class CBAM(nn.Module):
  """Convolutional Block Attention Module.

  Applies channel attention and then spatial attention to a feature map; the
  output has the same shape as the input.

  Args:
    channel: number of channels of the incoming feature map.
    ratio: reduction ratio forwarded to ``channel_attention_module``.
    kernel_size: convolution kernel size forwarded to
      ``spatial_attention_module``.
  """

  def __init__(self, channel, ratio=8, kernel_size=7):
    super().__init__()

    self.channel_layer = channel_attention_module(channel, ratio)
    self.spatial_layer = spatial_attention_module(kernel_size)

  def forward(self, x):
    """Return ``x`` refined by channel attention followed by spatial attention."""
    # the modules are applied sequentially (not in parallel): channel
    # attention first, then spatial attention on the channel-refined features
    x = self.channel_layer(x)
    x = self.spatial_layer(x)

    return x

In [7]:
# checking the results of cbam
x = torch.randn((8, 32, 128, 128))
cbam = CBAM(32)
y = cbam(x)
print(y.shape)

torch.Size([8, 32, 128, 128])
