# Import Libraries

In [47]:
import torch
import torch.nn as nn

# Channel Attention Module

In [48]:
class channel_attention_module(nn.Module):
  """Channel attention (CBAM): reweight each channel of a feature map.

  Average- and max-pooled channel descriptors go through a shared
  bottleneck MLP. The two MLP outputs are summed and squashed with a
  sigmoid, giving one gate per channel that multiplies the input.

  Args:
    ch: number of input channels (the MLP's input/output width).
    ratio: bottleneck reduction factor for the shared MLP. The hidden width
      is ``ch // ratio``, clamped to at least 1 so that ``ch < ratio`` does
      not produce a zero-width layer.
  """

  def __init__(self, ch, ratio=8):
    super().__init__()

    self.avg_pool = nn.AdaptiveAvgPool2d(1)
    self.max_pool = nn.AdaptiveMaxPool2d(1)

    # Clamp to 1: when ch < ratio, ch // ratio == 0 would create empty Linear
    # layers whose output is always zero, making every gate a constant 0.5.
    hidden = max(ch // ratio, 1)

    # Shared MLP applied to both pooled descriptors.
    self.mlp = nn.Sequential(
        nn.Linear(ch, hidden, bias=False),
        nn.ReLU(inplace=True),
        nn.Linear(hidden, ch, bias=False),
    )

    self.sigmoid = nn.Sigmoid()

  def forward(self, x):
    """Return ``x`` scaled by per-channel attention weights.

    Args:
      x: feature tensor; assumed (N, C, H, W) with C == ``ch`` — TODO
        confirm. The pooling is 2-D and the MLP expects ``ch`` features.

    Returns:
      Tensor with the same shape as ``x``.
    """
    # Pool to (..., C, 1, 1), then drop the two trailing singleton dims so
    # the Linear layers see (..., C).
    avg_desc = self.mlp(self.avg_pool(x).squeeze(-1).squeeze(-1))
    max_desc = self.mlp(self.max_pool(x).squeeze(-1).squeeze(-1))

    # Sum (not concatenate) the descriptors, then restore the two trailing
    # dims so the gates broadcast over H and W.
    gates = self.sigmoid(avg_desc + max_desc).unsqueeze(-1).unsqueeze(-1)
    return x * gates

# Spatial Attention Module

In [49]:
class spatial_attention_module(nn.Module):
  """Spatial attention (CBAM): reweight each spatial location.

  The channel-wise mean and max maps are stacked into a 2-channel tensor and
  convolved down to one channel. A sigmoid then turns the result into a
  spatial gate that multiplies the input.

  Args:
    kernel_size: convolution kernel size, an int or an (h, w) pair. Every
      side must be a positive odd integer so that ``k // 2`` padding keeps
      the spatial size unchanged.

  Raises:
    ValueError: if any kernel side is not a positive odd integer.
  """

  def __init__(self, kernel_size = 7):
    super().__init__()

    if isinstance(kernel_size, int):
      sides = (kernel_size, kernel_size)
    else:
      sides = tuple(kernel_size)
    if any(k < 1 or k % 2 == 0 for k in sides):
      raise ValueError(
          f"kernel_size must be a positive odd integer, got {kernel_size}")

    # Padding must follow kernel_size. A fixed padding of 3 only preserves
    # H/W for kernel_size == 7, and the final multiply needs the gate's H/W
    # to match the input's.
    padding = tuple(k // 2 for k in sides)
    self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
    self.sigmoid = nn.Sigmoid()

  def forward(self, x):
    """Return ``x`` scaled by a per-location attention map.

    Args:
      x: feature tensor; assumed (N, C, H, W) — TODO confirm. Channel
        statistics are taken over dim 1.

    Returns:
      Tensor with the same shape as ``x``.
    """
    # Channel-wise average and max, each keeping dim 1 with size 1.
    avg_map = torch.mean(x, dim=1, keepdim=True)
    max_map, _ = torch.max(x, dim=1, keepdim=True)

    # Stack into 2 channels to match the conv's 2 input channels.
    gate = self.sigmoid(self.conv(torch.cat([avg_map, max_map], dim=1)))

    return x * gate

# CBAM

In [50]:
class CBAM(nn.Module):
  """Convolutional Block Attention Module.

  Applies channel attention, then spatial attention to the result. The
  output has the same shape as the input.

  Args:
    channel: number of input channels, forwarded to the channel attention.
    ratio: MLP reduction factor for the channel attention. Defaults to 8,
      the channel attention module's own default.
    kernel_size: convolution kernel size for the spatial attention. Defaults
      to 7, the spatial attention module's own default.
  """

  def __init__(self, channel, ratio=8, kernel_size=7):
    super().__init__()

    self.channel_attention = channel_attention_module(channel, ratio=ratio)
    self.spatial_attention = spatial_attention_module(kernel_size=kernel_size)

  def forward(self, x):
    """Refine ``x`` with channel attention, then spatial attention."""
    # Sequential: the spatial gate is computed from the channel-refined
    # features, not from the raw input.
    x = self.channel_attention(x)
    x = self.spatial_attention(x)
    return x