# Conv, Attention & CoAtNet

In Part 1, we covered the foundational concepts of deep learning, including Convolutional Neural Networks (CNNs) and Vision Transformers (ViTs). We saw that while ViTs can outperform CNNs, they typically require large datasets and extensive computational resources. In this chapter, we explore how to combine the strengths of convolution and attention mechanisms.

## Convolution and Attention: Strengths and Weaknesses

### Convolutional Neural Networks (CNNs)

**Strengths:**
- **Locality and Translation Equivariance:** Convolutions look for patterns within local receptive fields and apply the same kernel at every position, so a feature is detected wherever it appears in the image. Stacking layers builds these local features into spatial hierarchies.
- **Parameter Efficiency:** Weight sharing in convolutions significantly reduces the number of parameters compared to fully connected layers.
- **Computational Efficiency:** CNNs are highly optimized for parallel computation on modern hardware, making them fast and efficient for image processing tasks.
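
To make the weight-sharing point concrete, the following sketch compares the parameter count of a 3×3 convolution with that of a fully connected layer that maps the same input to the same output size. The 8×8 image size, the channel counts, and the helper name `count_params` are arbitrary choices made for illustration.

```python
import torch.nn as nn

def count_params(module: nn.Module) -> int:
    """Total number of learnable parameters in a module."""
    return sum(p.numel() for p in module.parameters())

# Illustrative sizes (assumptions): a 3-channel 8x8 input mapped to 64 channels at 8x8.
in_ch, out_ch, side = 3, 64, 8

conv = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
fc = nn.Linear(in_ch * side * side, out_ch * side * side)

conv_params = count_params(conv)  # 3 * 64 * 3 * 3 weights + 64 biases
fc_params = count_params(fc)      # 192 * 4096 weights + 4096 biases

print(f"conv: {conv_params:,} parameters")
print(f"fc:   {fc_params:,} parameters")
```

The convolution's parameter count does not depend on the image size, whereas the fully connected layer grows with both the input and the output resolution.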

**Weaknesses:**
- **Limited Global Context:** Due to their local receptive fields, CNNs may struggle to capture long-range dependencies and global context in images.
- **Fixed Architecture:** Once trained, a convolutional kernel applies the same weights to every input regardless of its content, which can limit how well CNNs adapt to varying data distributions.
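
The limited-context point can be checked empirically. The sketch below measures how many input pixels influence the centre output pixel of a stack of 3×3 convolutions by looking at which input gradients are non-zero. The helper name `receptive_field_size`, the 17×17 input, and the all-ones weights are illustrative assumptions, not part of any model in this chapter.

```python
import torch
import torch.nn as nn

def receptive_field_size(num_layers: int, side: int = 17) -> int:
    """Count input pixels that influence the centre output pixel of a conv stack."""
    layers = [
        nn.Conv2d(1, 1, kernel_size=3, padding=1, bias=False)
        for _ in range(num_layers)
    ]
    for layer in layers:
        nn.init.ones_(layer.weight)  # all-positive weights so no gradient cancels out
    net = nn.Sequential(*layers)

    x = torch.zeros(1, 1, side, side, requires_grad=True)
    out = net(x)
    # Backpropagate from the centre output pixel only.
    out[0, 0, side // 2, side // 2].backward()
    return int((x.grad != 0).sum())

for n in (1, 2, 3):
    print(n, "layer(s):", receptive_field_size(n), "input pixels")
```

Each extra 3×3 layer widens the receptive field by only two pixels per axis, so relating distant image regions requires many layers or downsampling.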

### Vision Transformers (ViTs)

**Strengths:**
- **Global Context Capture:** ViTs use self-attention mechanisms to model long-range dependencies and global relationships between different parts of an image.
- **Flexibility:** The attention mechanism allows ViTs to adapt to various data structures and distributions without being constrained by local receptive fields.
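
As a rough illustration of global context, the sketch below perturbs a single token and checks how many output tokens of one self-attention layer change. The sizes (16 tokens, dimension 32, 4 heads) and the randomly initialised, untrained layer are assumptions chosen only for the demonstration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Illustrative sizes (assumptions): 16 tokens of dimension 32, 4 heads.
mhsa = nn.MultiheadAttention(embed_dim=32, num_heads=4, batch_first=True).eval()
tokens = torch.randn(1, 16, 32)

perturbed = tokens.clone()
perturbed[0, 0] += 5.0  # modify only the first token

with torch.no_grad():
    out_a, _ = mhsa(tokens, tokens, tokens)
    out_b, _ = mhsa(perturbed, perturbed, perturbed)

# Per-token change in the output of a single attention layer.
change = (out_a - out_b).abs().sum(dim=-1).squeeze(0)
num_changed = int((change > 0).sum())
print(f"{num_changed} of {change.numel()} output tokens changed")
```

Because every token attends to every other token, a change at one position can reach all positions in a single layer, which a single small convolution cannot do.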

**Weaknesses:**
- **Data and Compute Intensive:** ViTs generally require large datasets and substantial computational resources for effective training.
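- **Parameter and Memory Heavy:** ViTs tend to have many parameters, and self-attention compares every token with every other token, so its compute and memory cost grows quadratically with the number of tokens. This makes ViTs less memory-efficient than CNNs, especially at high resolutions.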
- **Generalization Ability:** Although Transformer models have a large model capacity, they may generalize worse than convolutional networks because they lack a suitable inductive bias.
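
One way to see where the memory cost of self-attention comes from is to look at the attention matrix itself, which has one entry per pair of tokens. The sketch below assumes a ViT-style split into non-overlapping 16×16 patches; the helper `num_tokens` and the embedding size of 64 are hypothetical values for illustration.

```python
import torch
import torch.nn as nn

def num_tokens(image_size: int, patch_size: int) -> int:
    """Number of non-overlapping square patches in a square image."""
    return (image_size // patch_size) ** 2

mhsa = nn.MultiheadAttention(embed_dim=64, num_heads=4, batch_first=True)

n = num_tokens(224, 16)  # 14 * 14 patches
tokens = torch.randn(1, n, 64)
_, weights = mhsa(tokens, tokens, tokens)  # returned weights are averaged over heads
print(weights.shape)  # (batch, n, n): one row of attention weights per token

for size in (224, 448):
    n = num_tokens(size, 16)
    print(f"{size}px -> {n} tokens -> {n * n:,} attention entries per head")
```

Doubling the image side quadruples the number of tokens and multiplies the size of each attention matrix by sixteen.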

## Combining Convolution and Attention: The Idea Behind CoAtNet

To leverage the strengths of both CNNs and ViTs, researchers have proposed hybrid architectures that integrate convolutional layers and attention mechanisms. One prominent example of such an architecture is CoAtNet (Convolution and Attention Network).

### CoAtNet: An Overview

CoAtNet is designed to combine the efficiency of CNNs with the global context modeling capability of attention mechanisms. It consists of several stages, each utilizing different combinations of convolution and attention blocks:

1. **Initial Convolutional Layers:** The initial stages use convolutional layers to efficiently capture local features and build a strong spatial hierarchy.
2. **Transition to Attention:** In the middle stages, CoAtNet gradually introduces attention mechanisms to capture global dependencies and relationships.
3. **Attention Dominated Stages:** The final stages rely more on self-attention to refine and enhance the global context understanding.

This progressive combination ensures that CoAtNet benefits from the parameter and computational efficiency of convolutions in the early stages, while also leveraging the powerful global feature extraction capabilities of attention mechanisms in the later stages.
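
A quick calculation shows why deferring attention to the later stages is cheaper. The sketch below assumes a 224×224 input whose height and width are halved at each of five stages; both the number of stages and the halving schedule are assumptions for illustration, and `stage_token_counts` is a hypothetical helper.

```python
def stage_token_counts(image_size=224, num_stages=5):
    """Spatial positions per stage, assuming each stage halves height and width."""
    counts = []
    side = image_size
    for _ in range(num_stages):
        side //= 2
        counts.append(side * side)
    return counts

counts = stage_token_counts()
for stage, n in enumerate(counts):
    # Full self-attention needs one score per pair of positions, hence n * n.
    print(f"stage {stage}: {n:>6,} positions -> {n * n:>12,} pairwise attention entries")
```

Under these assumptions, full self-attention at the first stage would need several orders of magnitude more pairwise scores than at the last one, which is why convolution handles the high-resolution stages.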

### CoAtNet Architecture

Like ViT, CoAtNet is a deep architecture, so you need to train it on a large dataset to see its full performance.

Here is a simplified implementation of CoAtNet in PyTorch to give you an idea of how such a model can be constructed. Note that this simplified version uses MBConv blocks in stages `s2` to `s4` and a single attention block in `s5`.

import torch
import torch.nn as nn
import torch.nn.functional as F

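class SqueezeExcite(nn.Module):
    """Squeeze-and-Excitation: reweight channels using globally pooled statistics."""

    def __init__(self, in_ch, reduction=4):
        super().__init__()
        reduced_ch = in_ch // reduction
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),          # squeeze: (B, C, H, W) -> (B, C, 1, 1)
            nn.Conv2d(in_ch, reduced_ch, 1),  # bottleneck
            nn.ReLU(),
            nn.Conv2d(reduced_ch, in_ch, 1),  # restore the channel count
            nn.Sigmoid()                      # per-channel gate in (0, 1)
        )

    def forward(self, x):
        # Broadcast the per-channel gate over the spatial dimensions.
        return x * self.se(x)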

class MBConvBlock(nn.Module):
    """Inverted residual block: 1x1 expand -> depthwise conv -> SE -> 1x1 project."""

    def __init__(self, in_ch, out_ch, kernel_size=3, stride=1, expand_ratio=4, reduction=4):
        super().__init__()
        mid_ch = in_ch * expand_ratio
        # The skip connection is only valid when input and output shapes match,
        # i.e. no downsampling and no change in channel count.
        self.use_residual = stride == 1 and in_ch == out_ch
        self.expand_conv = nn.Conv2d(in_ch, mid_ch, 1, bias=False)
        self.bn0 = nn.BatchNorm2d(mid_ch)
        self.dw_conv = nn.Conv2d(mid_ch, mid_ch, kernel_size, stride,
                                 padding=(kernel_size - 1) // 2, groups=mid_ch, bias=False)
        self.bn1 = nn.BatchNorm2d(mid_ch)
        self.se = SqueezeExcite(mid_ch, reduction)
        self.project_conv = nn.Conv2d(mid_ch, out_ch, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_ch)

    def forward(self, x):
        identity = x
        x = F.relu6(self.bn0(self.expand_conv(x)))  # 1x1 expansion
        x = F.relu6(self.bn1(self.dw_conv(x)))      # depthwise spatial conv
        x = self.se(x)                              # channel reweighting
        x = self.bn2(self.project_conv(x))          # 1x1 projection, no activation
        if self.use_residual:
            x = x + identity
        return x

class AttentionBlock(nn.Module):
    """Multi-head self-attention over the flattened spatial positions of a feature map."""

    def __init__(self, dim, num_heads, dropout=0.1):
        super().__init__()
        self.mhsa = nn.MultiheadAttention(dim, num_heads, dropout=dropout)
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        b, c, h, w = x.shape
        # Treat each spatial position as a token: (B, C, H, W) -> (H*W, B, C)
        x = x.flatten(2).permute(2, 0, 1)
        x = self.norm(x)
        x, _ = self.mhsa(x, x, x)
        # Restore the feature-map layout: (H*W, B, C) -> (B, C, H, W).
        # reshape (rather than view) also handles the non-contiguous permuted tensor.
        x = x.permute(1, 2, 0).reshape(b, c, h, w)
        return x

class CoAtNet(nn.Module):
    """Simplified CoAtNet-style model: convolutional stages followed by one attention stage."""

    def __init__(self, image_size=224, in_channels=3, num_classes=1000):
        # image_size is accepted for convenience but is not used by this simplified model.
        super().__init__()
        # Resolution comments below assume a 224x224 input.
        # Stem: 224 -> 112 (strided conv) -> 56 (max pool)
        self.s1 = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )
        # Convolutional stages: 56 -> 28 -> 14 -> 7
        self.s2 = nn.Sequential(
            MBConvBlock(64, 128, stride=2),
            MBConvBlock(128, 128, stride=1)
        )
        self.s3 = nn.Sequential(
            MBConvBlock(128, 256, stride=2),
            MBConvBlock(256, 256, stride=1)
        )
        self.s4 = nn.Sequential(
            MBConvBlock(256, 512, stride=2),
            MBConvBlock(512, 512, stride=1)
        )
        # Attention stage: self-attention over the 7x7 = 49 positions, then global pooling
        self.s5 = nn.Sequential(
            AttentionBlock(512, num_heads=8),
            nn.AdaptiveAvgPool2d((1, 1))
        )
        self.fc = nn.Linear(512, num_classes)

    def forward(self, x):
        x = self.s1(x)
        x = self.s2(x)
        x = self.s3(x)
        x = self.s4(x)
        x = self.s5(x)
        x = torch.flatten(x, 1)  # (B, 512, 1, 1) -> (B, 512)
        x = self.fc(x)
        return x

# Example usage:
model = CoAtNet(image_size=224, in_channels=3, num_classes=1000)
input_tensor = torch.randn(8, 3, 224, 224)  # Batch of 8, 3x224x224 images
output = model(input_tensor)
print(output.shape)  # Should output torch.Size([8, 1000])
print(model)

torch.Size([8, 1000])
CoAtNet(
  (s1): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  )
  (s2): Sequential(
    (0): MBConvBlock(
      (expand_conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn0): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (dw_conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=256, bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (se): SqueezeExcite(
        (se): Sequential(
          (0): AdaptiveAvgPool2d(output_size=1)
          (1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
          (2): ReLU()
          (3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))
    