<a href="https://colab.research.google.com/github/BhardwajArjit/Research-Paper-Replication/blob/main/ResNet%2BCBAM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## This notebook implements ResNet with CBAM.

ResNet (Residual Network) is a deep neural network architecture that uses skip connections to facilitate training of very deep convolutional neural networks.

CBAM (Convolutional Block Attention Module) aims to enhance the feature representation of convolutional neural networks by incorporating channel-wise and spatial-wise attention mechanisms.

## 0. Setup

In [2]:
import torch
import torch.nn as nn
import math
from torchsummary import summary
import torch.utils.model_zoo as model_zoo

## 1. Pretrained ResNet weight URLs

In [3]:
# Checkpoint download URLs keyed by architecture name. Every file lives under
# download.pytorch.org/models/ and is named "<arch>-<hash prefix>.pth", so the
# table only needs to spell out the (arch, hash prefix) pairs.
model_urls = {
    arch: f'https://download.pytorch.org/models/{arch}-{digest}.pth'
    for arch, digest in (
        ('resnet18', '5c106cde'),
        ('resnet34', '333f7ec4'),
        ('resnet50', '19c8e357'),
        ('resnet101', '5d3b4d8f'),
        ('resnet152', 'b121ed2d'),
    )
}

## 2. ResNet + CBAM

In [4]:
def conv3x3(in_planes, out_planes, stride=1):
  """
  Build a bias-free 3x3 convolution with one pixel of zero padding.

  With stride 1 the padding keeps the spatial size unchanged; larger strides
  downsample accordingly.

  Parameters:
  - in_planes (int): Number of input channels.
  - out_planes (int): Number of output channels.
  - stride (int, optional): Stride of the convolution. Default is 1.

  Returns:
  - nn.Conv2d: The configured 3x3 convolutional layer.
  """
  layer = nn.Conv2d(
      in_channels=in_planes,
      out_channels=out_planes,
      kernel_size=3,
      stride=stride,
      padding=1,
      bias=False,
  )
  return layer

In [5]:
class ChannelAttention(nn.Module):
  """
  Channel Attention Module.

  Squeezes the input to one descriptor per channel with global average
  pooling and global max pooling, passes both descriptors through the same
  two-layer bottleneck (implemented with 1x1 convolutions), sums the results
  and applies a sigmoid. The output has one weight in [0, 1] per channel and
  is meant to be multiplied with the input feature map.

  Parameters:
  - in_planes (int): Number of input channels.
  - ratio (int, optional): Reduction ratio for the intermediate channels. Default is 16.

  Raises:
  - ValueError: If `ratio` is not positive.
  """
  def __init__(self, in_planes, ratio=16):
    super(ChannelAttention, self).__init__()
    if ratio <= 0:
      raise ValueError('ratio must be a positive integer')
    # Width of the bottleneck. Honour the `ratio` argument (it used to be
    # ignored in favour of a hard-coded 16) and never drop below one channel,
    # otherwise inputs narrower than `ratio` would get an empty hidden layer.
    hidden_planes = max(in_planes // ratio, 1)

    # Adaptive average pooling layer to capture global average information
    self.avg_pool = nn.AdaptiveAvgPool2d(1)
    # Adaptive max pooling layer to capture global max information
    self.max_pool = nn.AdaptiveMaxPool2d(1)
    # First 1x1 convolution for dimension reduction
    self.fc1 = nn.Conv2d(in_planes, hidden_planes, 1, bias=False)
    # ReLU activation function
    self.relu1 = nn.ReLU()
    # Second 1x1 convolution for dimension restoration
    self.fc2 = nn.Conv2d(hidden_planes, in_planes, 1, bias=False)
    # Sigmoid activation function to produce attention weights between 0 and 1
    self.sigmoid = nn.Sigmoid()

  def forward(self, x):
    """
    Compute the channel attention weights for `x`.

    Parameters:
    - x (torch.Tensor): Feature map with `in_planes` channels in dimension 1.

    Returns:
    - torch.Tensor: Attention weights with spatial size 1x1 per channel.
    """
    # Both pooled descriptors go through the *same* fc1 -> ReLU -> fc2 stack
    avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
    max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
    # Element-wise addition of the average and max pooling results
    out = avg_out + max_out
    # Apply the sigmoid activation to produce attention weights
    return self.sigmoid(out)

In [6]:
class SpatialAttention(nn.Module):
  """
  Spatial Attention Module.

  Collapses the channel dimension into two maps (mean and max), stacks them,
  and turns them into a single-channel attention map with one convolution
  followed by a sigmoid.

  Parameters:
  - kernel_size (int, optional): Size of the convolutional kernel. Must be 3 or 7. Default is 7.
  """
  def __init__(self, kernel_size=7):
    super().__init__()

    assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
    # Half the (odd) kernel size keeps the spatial resolution: 3 -> 1, 7 -> 3
    self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
    # Squashes the convolution output into attention weights between 0 and 1
    self.sigmoid = nn.Sigmoid()

  def forward(self, x):
    """
    Compute the spatial attention map for `x`.

    Parameters:
    - x (torch.Tensor): Feature map with channels in dimension 1.

    Returns:
    - torch.Tensor: Single-channel attention map with values in [0, 1].
    """
    # Per-position statistics across channels
    channel_mean = x.mean(dim=1, keepdim=True)
    channel_max = x.max(dim=1, keepdim=True).values
    # Two-channel descriptor -> one-channel attention logits -> weights
    descriptor = torch.cat((channel_mean, channel_max), dim=1)
    return self.sigmoid(self.conv1(descriptor))

In [7]:
class BasicBlock(nn.Module):
  """
  Two-convolution residual block with CBAM attention.

  The main path is conv3x3 -> BN -> ReLU -> conv3x3 -> BN, followed by channel
  attention and then spatial attention. The shortcut is added afterwards and
  the sum goes through a final ReLU.

  Parameters:
  - inplanes (int): Number of input channels.
  - planes (int): Number of output channels.
  - stride (int, optional): Stride of the first convolution. Default is 1.
  - downsample (nn.Module, optional): Module applied to the input so the
    shortcut matches the main path's shape. Default is None.
  """
  expansion = 1

  def __init__(self, inplanes, planes, stride=1, downsample=None):
    super().__init__()

    # Stage 1: (possibly strided) 3x3 conv + BN, with a shared in-place ReLU
    self.conv1 = conv3x3(inplanes, planes, stride)
    self.bn1 = nn.BatchNorm2d(planes)
    self.relu = nn.ReLU(inplace=True)

    # Stage 2: 3x3 conv + BN (no activation until after the residual add)
    self.conv2 = conv3x3(planes, planes)
    self.bn2 = nn.BatchNorm2d(planes)

    # CBAM: channel attention first, spatial attention second
    self.ca = ChannelAttention(planes)
    self.sa = SpatialAttention()

    self.downsample = downsample
    self.stride = stride

  def forward(self, x):
    """
    Run the block on `x`.

    Parameters:
    - x (torch.Tensor): Input tensor.

    Returns:
    - torch.Tensor: Output of the residual block.
    """
    # Main path
    out = self.relu(self.bn1(self.conv1(x)))
    out = self.bn2(self.conv2(out))

    # Re-weight the features channel-wise, then position-wise
    out = self.ca(out) * out
    out = self.sa(out) * out

    # Shortcut: identity unless a projection was supplied
    shortcut = x if self.downsample is None else self.downsample(x)

    out += shortcut
    return self.relu(out)

In [8]:
class Bottleneck(nn.Module):
  """
  Bottleneck Residual Block with CBAM attention.

  The main path is 1x1 conv (to `planes` channels) -> 3x3 conv (carries the
  stride) -> 1x1 conv (to `planes * expansion` channels), each followed by
  batch normalization, with ReLU after the first two. Channel attention and
  then spatial attention are applied before the shortcut is added.

  Parameters:
  - inplanes (int): Number of input channels.
  - planes (int): Number of channels of the inner layers. The block outputs
    `planes * expansion` channels.
  - stride (int, optional): Stride of the 3x3 convolution. Default is 1.
  - downsample (nn.Module, optional): Module applied to the input so the
    shortcut matches the main path's shape. Default is None.
  """
  expansion = 4

  def __init__(self, inplanes, planes, stride=1, downsample=None):
    super(Bottleneck, self).__init__()

    # Derive the output width from the class attribute instead of repeating
    # the literal 4, so it always agrees with what ResNet._make_layer
    # computes from `block.expansion`.
    out_planes = planes * self.expansion

    # First 1x1 convolutional layer
    self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
    self.bn1 = nn.BatchNorm2d(planes)

    # Second 3x3 convolutional layer with stride
    self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                            padding=1, bias=False)
    self.bn2 = nn.BatchNorm2d(planes)

    # Third 1x1 convolutional layer with increased output channels
    self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
    self.bn3 = nn.BatchNorm2d(out_planes)

    # ReLU activation function (shared by all stages, applied in place)
    self.relu = nn.ReLU(inplace=True)

    # Channel attention and spatial attention modules
    self.ca = ChannelAttention(out_planes)
    self.sa = SpatialAttention()

    self.downsample = downsample
    self.stride = stride

  def forward(self, x):
    """
    Forward pass of the Bottleneck.

    Parameters:
    - x (torch.Tensor): Input tensor.

    Returns:
    - torch.Tensor: Output tensor after applying the Bottleneck.
    """
    # Save the input tensor for the residual connection
    residual = x

    # 1x1 convolutional block
    out = self.conv1(x)
    out = self.bn1(out)
    out = self.relu(out)

    # 3x3 convolutional block (applies the stride)
    out = self.conv2(out)
    out = self.bn2(out)
    out = self.relu(out)

    # Expanding 1x1 convolutional block (no ReLU until after the residual add)
    out = self.conv3(out)
    out = self.bn3(out)

    # Apply channel attention and spatial attention
    out = self.ca(out) * out
    out = self.sa(out) * out

    # If downsample is provided, apply it to the residual
    if self.downsample is not None:
      residual = self.downsample(x)

    # Add the residual connection
    out += residual
    out = self.relu(out)

    return out

In [9]:
class ResNet(nn.Module):
  """
  ResNet model.

  Stem (7x7 conv, BN, ReLU, 3x3 max pool), four stages of residual blocks,
  global average pooling and a linear classifier.

  The pooling before the classifier is adaptive, so the network accepts any
  input resolution that survives the total 32x downsampling, not only the
  224x224 inputs that produce a 7x7 final feature map.

  Parameters:
  - block (nn.Module): Residual block class (e.g., BasicBlock or Bottleneck).
    It must expose an `expansion` class attribute and accept
    `(inplanes, planes, stride, downsample)`.
  - layers (list): List of integers indicating the number of blocks in each layer.
  - num_classes (int, optional): Number of output classes. Default is 1000.
  """

  def __init__(self, block, layers, num_classes=1000):
    # Initialise nn.Module before touching any attribute on self
    super(ResNet, self).__init__()
    # Channel count fed into the next block; updated by _make_layer
    self.inplanes = 64

    # Initial convolutional layer
    self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                            bias=False)
    self.bn1 = nn.BatchNorm2d(64)
    self.relu = nn.ReLU(inplace=True)
    self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    # Residual layers
    self.layer1 = self._make_layer(block, 64, layers[0])
    self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
    self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
    self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

    # Global average pooling and fully connected layer. The adaptive pool
    # always yields a 1x1 map; for a 7x7 input it equals the former fixed
    # AvgPool2d(7) and it has no parameters, so checkpoints are unaffected.
    self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
    self.fc = nn.Linear(512 * block.expansion, num_classes)

    # Initialization of weights and batch normalization parameters
    for m in self.modules():
      if isinstance(m, nn.Conv2d):
        # Normal init with variance 2 / (kernel area * out_channels)
        n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        m.weight.data.normal_(0, math.sqrt(2. / n))
      elif isinstance(m, nn.BatchNorm2d):
        m.weight.data.fill_(1)
        m.bias.data.zero_()

  def _make_layer(self, block, planes, blocks, stride=1):
    """
    Helper function to create a residual layer.

    Parameters:
    - block (nn.Module): Residual block class (e.g., BasicBlock or Bottleneck).
    - planes (int): Base channel count of each block; the layer outputs
      `planes * block.expansion` channels.
    - blocks (int): Number of blocks in the layer.
    - stride (int, optional): Stride for the first block. Default is 1.

    Returns:
    - nn.Sequential: Residual layer.
    """
    out_planes = planes * block.expansion

    # The first block needs a projection shortcut whenever it changes the
    # spatial size or the channel count.
    downsample = None
    if stride != 1 or self.inplanes != out_planes:
      downsample = nn.Sequential(
          nn.Conv2d(self.inplanes, out_planes,
                    kernel_size=1, stride=stride, bias=False),
          nn.BatchNorm2d(out_planes),
      )

    layers = [block(self.inplanes, planes, stride, downsample)]
    # Every following block (and the next layer) sees the expanded width
    self.inplanes = out_planes
    for _ in range(1, blocks):
      layers.append(block(self.inplanes, planes))

    return nn.Sequential(*layers)

  def forward(self, x):
    """
    Forward pass of the ResNet model.

    Parameters:
    - x (torch.Tensor): Input tensor with 3 channels in dimension 1.

    Returns:
    - torch.Tensor: Output tensor with `num_classes` values per sample.
    """
    x = self.conv1(x)
    x = self.bn1(x)
    x = self.relu(x)
    x = self.maxpool(x)

    x = self.layer1(x)
    x = self.layer2(x)
    x = self.layer3(x)
    x = self.layer4(x)

    x = self.avgpool(x)
    # Flatten everything but the batch dimension for the classifier
    x = torch.flatten(x, 1)
    x = self.fc(x)

    return x

In [10]:
def resnet18_cbam(pretrained=False, **kwargs):
  """
  Constructs a ResNet-18 model with CBAM attention in every residual block.

  Args:
      pretrained (bool): If True, initialises the layers shared with the plain
          ResNet-18 from the checkpoint listed in `model_urls` (pre-trained on
          ImageNet). Layers without a matching checkpoint entry keep their
          fresh initialisation.
      **kwargs: Forwarded to `ResNet` (e.g. `num_classes`).
  """
  model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
  if pretrained:
    pretrained_state_dict = model_zoo.load_url(model_urls['resnet18'])
    now_state_dict = model.state_dict()
    # Copy only tensors that match this model by name and shape. Without the
    # shape check, a non-default `num_classes` makes the checkpoint's `fc`
    # tensors overwrite the model's and load_state_dict fails on the mismatch.
    compatible = {
        name: tensor
        for name, tensor in pretrained_state_dict.items()
        if name in now_state_dict and tensor.shape == now_state_dict[name].shape
    }
    now_state_dict.update(compatible)
    model.load_state_dict(now_state_dict)
  return model

In [11]:
def resnet34_cbam(pretrained=False, **kwargs):
  """
  Constructs a ResNet-34 model with CBAM attention in every residual block.

  Args:
      pretrained (bool): If True, initialises the layers shared with the plain
          ResNet-34 from the checkpoint listed in `model_urls` (pre-trained on
          ImageNet). Layers without a matching checkpoint entry keep their
          fresh initialisation.
      **kwargs: Forwarded to `ResNet` (e.g. `num_classes`).
  """
  model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
  if pretrained:
    pretrained_state_dict = model_zoo.load_url(model_urls['resnet34'])
    now_state_dict = model.state_dict()
    # Copy only tensors that match this model by name and shape. Without the
    # shape check, a non-default `num_classes` makes the checkpoint's `fc`
    # tensors overwrite the model's and load_state_dict fails on the mismatch.
    compatible = {
        name: tensor
        for name, tensor in pretrained_state_dict.items()
        if name in now_state_dict and tensor.shape == now_state_dict[name].shape
    }
    now_state_dict.update(compatible)
    model.load_state_dict(now_state_dict)
  return model

In [12]:
def resnet50_cbam(pretrained=False, **kwargs):
  """
  Constructs a ResNet-50 model with CBAM attention in every residual block.

  Args:
      pretrained (bool): If True, initialises the layers shared with the plain
          ResNet-50 from the checkpoint listed in `model_urls` (pre-trained on
          ImageNet). Layers without a matching checkpoint entry keep their
          fresh initialisation.
      **kwargs: Forwarded to `ResNet` (e.g. `num_classes`).
  """
  model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
  if pretrained:
    pretrained_state_dict = model_zoo.load_url(model_urls['resnet50'])
    now_state_dict = model.state_dict()
    # Copy only tensors that match this model by name and shape. Without the
    # shape check, a non-default `num_classes` makes the checkpoint's `fc`
    # tensors overwrite the model's and load_state_dict fails on the mismatch.
    compatible = {
        name: tensor
        for name, tensor in pretrained_state_dict.items()
        if name in now_state_dict and tensor.shape == now_state_dict[name].shape
    }
    now_state_dict.update(compatible)
    model.load_state_dict(now_state_dict)
  return model

In [13]:
def resnet101_cbam(pretrained=False, **kwargs):
  """
  Constructs a ResNet-101 model with CBAM attention in every residual block.

  Args:
      pretrained (bool): If True, initialises the layers shared with the plain
          ResNet-101 from the checkpoint listed in `model_urls` (pre-trained on
          ImageNet). Layers without a matching checkpoint entry keep their
          fresh initialisation.
      **kwargs: Forwarded to `ResNet` (e.g. `num_classes`).
  """
  model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
  if pretrained:
    pretrained_state_dict = model_zoo.load_url(model_urls['resnet101'])
    now_state_dict = model.state_dict()
    # Copy only tensors that match this model by name and shape. Without the
    # shape check, a non-default `num_classes` makes the checkpoint's `fc`
    # tensors overwrite the model's and load_state_dict fails on the mismatch.
    compatible = {
        name: tensor
        for name, tensor in pretrained_state_dict.items()
        if name in now_state_dict and tensor.shape == now_state_dict[name].shape
    }
    now_state_dict.update(compatible)
    model.load_state_dict(now_state_dict)
  return model

In [14]:
def resnet152_cbam(pretrained=False, **kwargs):
  """
  Constructs a ResNet-152 model with CBAM attention in every residual block.

  Args:
      pretrained (bool): If True, initialises the layers shared with the plain
          ResNet-152 from the checkpoint listed in `model_urls` (pre-trained on
          ImageNet). Layers without a matching checkpoint entry keep their
          fresh initialisation.
      **kwargs: Forwarded to `ResNet` (e.g. `num_classes`).
  """
  model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
  if pretrained:
    pretrained_state_dict = model_zoo.load_url(model_urls['resnet152'])
    now_state_dict = model.state_dict()
    # Copy only tensors that match this model by name and shape. Without the
    # shape check, a non-default `num_classes` makes the checkpoint's `fc`
    # tensors overwrite the model's and load_state_dict fails on the mismatch.
    compatible = {
        name: tensor
        for name, tensor in pretrained_state_dict.items()
        if name in now_state_dict and tensor.shape == now_state_dict[name].shape
    }
    now_state_dict.update(compatible)
    model.load_state_dict(now_state_dict)
  return model

In [15]:
# Build the CBAM ResNet-18 with default arguments: `pretrained` defaults to
# False, so no checkpoint is downloaded.
resnet18 = resnet18_cbam()

In [16]:
# Print the per-layer table of output shapes and parameter counts for a
# 3x224x224 input (batch dimension excluded from the size tuple).
summary(resnet18, (3, 224, 224))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0
            Conv2d-5           [-1, 64, 56, 56]          36,864
       BatchNorm2d-6           [-1, 64, 56, 56]             128
              ReLU-7           [-1, 64, 56, 56]               0
            Conv2d-8           [-1, 64, 56, 56]          36,864
       BatchNorm2d-9           [-1, 64, 56, 56]             128
AdaptiveAvgPool2d-10             [-1, 64, 1, 1]               0
           Conv2d-11              [-1, 4, 1, 1]             256
             ReLU-12              [-1, 4, 1, 1]               0
           Conv2d-13             [-1, 64, 1, 1]             256
AdaptiveMaxPool2d-14             [-1, 6