<a href="https://colab.research.google.com/github/BhardwajArjit/Research-Paper-Replication/blob/main/CBAM_Replication.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## This notebook replicates the research paper titled "**CBAM: Convolutional Block Attention Module**" with PyTorch.

Link to the paper: https://arxiv.org/abs/1807.06521

CBAM (Convolutional Block Attention Module) aims to enhance the feature representation of convolutional neural networks by incorporating channel-wise and spatial-wise attention mechanisms.

The channel module focuses on "what" is meaningful in the given input image whereas the spatial module focuses on "where" the meaningful features are in the image.

## 0. Get setup

In [5]:
import torch
from torch import nn

## 1. Channel Attention Module

In [6]:
class channel_attention_module(nn.Module):
  """
  Channel Attention Module.

  Pools the input over its last two dimensions with both average and max
  pooling, feeds each pooled descriptor through one shared two-layer MLP,
  sums the two results, applies a sigmoid and scales the input by the
  resulting per-channel weights.

  Parameters:
  - channel (int): Number of input channels.
  - ratio (int, optional): Reduction ratio for the intermediate channels. Default is 8.
    If `channel // ratio` would be 0, the hidden width is clamped to 1.

  Returns (from forward):
  - Tensor with the same shape as the input.
  """
  def __init__(self, channel, ratio=8):
    super().__init__()

    # A hidden width of 0 (channel < ratio) would leave the bias-free MLP with
    # no weights at all, so its output would always be 0 and every attention
    # weight would be the constant sigmoid(0) = 0.5. Keep at least one unit.
    hidden = max(1, channel // ratio)

    # Adaptive average pooling layer to capture global average information
    self.avg_pool = nn.AdaptiveAvgPool2d(1)
    # Adaptive max pooling layer to capture global max information
    self.max_pool = nn.AdaptiveMaxPool2d(1)
    # Shared MLP (same weights for both pooled descriptors) modeling channel dependencies
    self.mlp = nn.Sequential(
        nn.Linear(in_features=channel,
                  out_features=hidden,
                  bias=False),
        nn.ReLU(inplace=True),
        nn.Linear(in_features=hidden,
                  out_features=channel,
                  bias=False)
    )
    # Sigmoid activation function to produce attention weights between 0 and 1
    self.sigmoid = nn.Sigmoid()

  def forward(self, x):
    # Drop the two trailing size-1 pooling dimensions so the Linear layers
    # act on the channel dimension.
    avg_desc = self.avg_pool(x).squeeze(-1).squeeze(-1)
    max_desc = self.max_pool(x).squeeze(-1).squeeze(-1)
    # Element-wise addition of the MLP-transformed average and max pooling results
    feats = self.mlp(avg_desc) + self.mlp(max_desc)
    # Apply sigmoid, then restore the two trailing size-1 dimensions so the
    # weights broadcast over the spatial dimensions of x
    feats = self.sigmoid(feats).unsqueeze(-1).unsqueeze(-1)
    # Scale the input features by the channel attention weights
    refined_features = x * feats

    return refined_features

In [7]:
# Smoke test: the channel attention module must return a tensor of the input's shape
x = torch.randn(8, 32, 128, 128)
module_1 = channel_attention_module(channel=32)
y = module_1(x)
print(y.shape)

torch.Size([8, 32, 128, 128])


## 2. Spatial Attention Module

In [8]:
class spatial_attention_module(nn.Module):
  """
  Spatial Attention Module.

  Reduces the input along the channel dimension (dim=1) with both mean and
  max, stacks the two maps, convolves them down to a single-channel map,
  applies a sigmoid and scales the input by the resulting weights.

  Parameters:
  - kernel_size (int or tuple of int, optional): Size of the convolutional kernel.
    Default is 7. Should be odd so the attention map has the same spatial
    size as the input.

  Returns (from forward):
  - Tensor with the same shape as the input.
  """
  def __init__(self, kernel_size=7):
    super().__init__()

    # Derive the padding from the kernel size instead of hard-coding 3
    # (which is only correct for a 7x7 kernel); k // 2 keeps the spatial
    # size unchanged for any odd k at stride 1.
    if isinstance(kernel_size, int):
      padding = kernel_size // 2
    else:
      padding = tuple(k // 2 for k in kernel_size)

    # Convolutional layer for generating spatial attention map
    self.conv = nn.Conv2d(in_channels=2,
                          out_channels=1,
                          kernel_size=kernel_size,
                          padding=padding,
                          bias=False)
    # Sigmoid activation function to produce attention weights between 0 and 1
    self.sigmoid = nn.Sigmoid()

  def forward(self, x):
    # Compute the mean along the channel dimension
    avg_map = torch.mean(x, dim=1, keepdim=True)
    # Compute the maximum along the channel dimension (indices are not needed)
    max_map, _ = torch.max(x, dim=1, keepdim=True)
    # Stack both maps as a 2-channel descriptor for the convolution
    feats = torch.cat([avg_map, max_map], dim=1)
    # Convolve down to a single-channel 2D spatial attention map
    feats = self.conv(feats)
    # Apply sigmoid activation to produce attention weights
    feats = self.sigmoid(feats)
    # Scale the input features by the spatial attention weights
    # (the single-channel map broadcasts across dim=1)
    refined_features = x * feats
    return refined_features

In [9]:
# Smoke test: the spatial attention module must return a tensor of the input's shape
x = torch.randn(8, 32, 128, 128)
module_2 = spatial_attention_module()
y = module_2(x)
print(y.shape)

torch.Size([8, 32, 128, 128])


## 3. CBAM (Convolutional Block Attention Module)

In [10]:
class CBAM(nn.Module):
  """
  Convolutional Block Attention Module (CBAM).

  Applies channel attention first and then spatial attention to the
  channel-refined result.

  Parameters:
  - channel (int): Number of input channels.
  - ratio (int, optional): Reduction ratio forwarded to channel_attention_module. Default is 8.
  - kernel_size (int, optional): Kernel size forwarded to spatial_attention_module. Default is 7.

  Returns (from forward):
  - Tensor produced by the spatial attention module from the channel-refined input.
  """
  def __init__(self, channel, ratio=8, kernel_size=7):
    super().__init__()
    # Channel Attention Module
    self.channel_layer = channel_attention_module(channel, ratio=ratio)
    # Spatial Attention Module
    self.spatial_layer = spatial_attention_module(kernel_size=kernel_size)

  def forward(self, x):
    # The two modules are applied sequentially, not in parallel: the spatial
    # module receives the output of the channel module.
    x = self.channel_layer(x)
    x = self.spatial_layer(x)

    return x

In [11]:
# Smoke test: the full CBAM block must return a tensor of the input's shape
x = torch.randn(8, 32, 128, 128)
cbam = CBAM(channel=32)
y = cbam(x)
print(y.shape)

torch.Size([8, 32, 128, 128])
