<a href="https://colab.research.google.com/github/BhardwajArjit/Research-Paper-Replication/blob/main/CBAM_Replication.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## This notebook replicates the research paper "**CBAM: Convolutional Block Attention Module**" in PyTorch.

Link to the paper: https://arxiv.org/abs/1807.06521

CBAM (Convolutional Block Attention Module) aims to enhance the feature representation of convolutional neural networks by incorporating channel-wise and spatial-wise attention mechanisms.

## 0. Setup

In [None]:
import torch
from torch import nn

## 1. Channel Attention Module

In [None]:
class channel_attention_module(nn.Module):
  """Channel attention: gate every channel of a feature map with a learned weight.

  The spatial dimensions are collapsed by global average pooling and global
  max pooling. Both pooled descriptors go through the same bottleneck MLP,
  the two results are summed, and the sigmoid of the sum scales the input
  channel by channel.

  Args:
    ch: Number of channels of the input feature map (size of the MLP input
      and output).
    ratio: Reduction ratio of the MLP bottleneck. The hidden width is
      ``ch // ratio``, clamped to at least 1 so that ``ch < ratio`` does not
      produce an empty hidden layer (which would make the gate a constant).
  """

  def __init__(self, ch, ratio=8):
    super().__init__()

    self.avg_pool = nn.AdaptiveAvgPool2d(1)
    self.max_pool = nn.AdaptiveMaxPool2d(1)

    hidden = max(ch // ratio, 1)

    # One MLP instance, so both pooled descriptors share the same weights.
    self.mlp = nn.Sequential(
        nn.Linear(ch, hidden, bias=False),
        nn.ReLU(inplace=True),
        nn.Linear(hidden, ch, bias=False),
    )

    self.sigmoid = nn.Sigmoid()

  def forward(self, x):
    """Return ``x`` scaled by its per-channel attention weights.

    Args:
      x: Feature map whose last two dimensions are spatial and whose channel
        dimension has size ``ch``.

    Returns:
      Tensor with the same shape as ``x``.
    """
    # Pooling leaves two trailing singleton dims; drop them so the channel
    # dimension is last, as nn.Linear expects.
    avg_desc = self.mlp(self.avg_pool(x).squeeze(-1).squeeze(-1))
    max_desc = self.mlp(self.max_pool(x).squeeze(-1).squeeze(-1))

    gate = self.sigmoid(avg_desc + max_desc)

    # Restore the two singleton spatial dims so the gate broadcasts over H, W.
    return x * gate[..., None, None]

In [None]:
# Smoke test: push a random batch through the channel attention module and
# print the shape of what comes out.
x = torch.randn(8, 32, 128, 128)
module_1 = channel_attention_module(ch=32)
y = module_1(x)
print(y.size())

torch.Size([8, 32, 128, 128])


## 2. Spatial Attention Module

In [None]:
class spatial_attention_module(nn.Module):
  """Spatial attention: gate every spatial location of a feature map.

  The channel dimension is collapsed by a mean and a max, the two resulting
  single-channel maps are concatenated and convolved down to one channel, and
  the sigmoid of that map scales the input location by location.

  Args:
    kernel_size: Size of the convolution kernel, an odd int or a pair of odd
      ints. The padding is derived from it (``k // 2`` per dimension) so the
      attention map keeps the spatial size of the input.

  Raises:
    ValueError: If any kernel dimension is even or smaller than 1; such a
      kernel cannot be padded symmetrically to preserve the spatial size.
  """

  def __init__(self, kernel_size=7):
    super().__init__()

    if isinstance(kernel_size, int):
      sizes = (kernel_size, kernel_size)
    else:
      sizes = tuple(kernel_size)

    if any(k < 1 or k % 2 == 0 for k in sizes):
      raise ValueError(
          f"kernel_size must be odd and positive, got {kernel_size!r}"
      )

    # "Same" padding for an odd kernel; for the default 7 this is 3.
    padding = tuple(k // 2 for k in sizes)

    self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
    self.sigmoid = nn.Sigmoid()

  def forward(self, x):
    """Return ``x`` scaled by its per-location attention weights.

    Args:
      x: Feature map with the channel dimension at index 1.

    Returns:
      Tensor with the same shape as ``x``.
    """
    mean_map = torch.mean(x, dim=1, keepdim=True)
    max_map, _ = torch.max(x, dim=1, keepdim=True)

    # Two-channel descriptor -> one-channel attention map in (0, 1).
    descriptor = torch.cat([mean_map, max_map], dim=1)
    gate = self.sigmoid(self.conv(descriptor))

    return x * gate

In [None]:
# Smoke test: push a random batch through the spatial attention module and
# print the shape of what comes out.
x = torch.randn(8, 32, 128, 128)
module_2 = spatial_attention_module()
y = module_2(x)
print(y.size())

torch.Size([8, 32, 128, 128])


## 3. CBAM (Convolutional Block Attention Module)

In [None]:
class CBAM(nn.Module):
  """Convolutional Block Attention Module: channel then spatial attention.

  Args:
    channel: Number of channels of the input feature map; forwarded to
      ``channel_attention_module``.
    ratio: Reduction ratio forwarded to ``channel_attention_module``.
    kernel_size: Convolution kernel size forwarded to
      ``spatial_attention_module``.
  """

  def __init__(self, channel, ratio=8, kernel_size=7):
    super().__init__()

    self.channel_layer = channel_attention_module(channel, ratio=ratio)
    self.spatial_layer = spatial_attention_module(kernel_size=kernel_size)

  def forward(self, x):
    """Apply channel attention to ``x``, then spatial attention to the result."""
    channel_refined = self.channel_layer(x)
    return self.spatial_layer(channel_refined)

In [None]:
# Smoke test: push a random batch through the full CBAM block and print the
# shape of what comes out.
x = torch.randn(8, 32, 128, 128)
model_1 = CBAM(channel=32)
y = model_1(x)
print(y.size())

torch.Size([8, 32, 128, 128])
