In [1]:
import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l

# Residual Blocks

Residual blocks are fundamental building blocks of modern deep neural networks, most notably the ResNet architectures. They address the *degradation problem*: adding more layers to an already deep "plain" network can make even its training accuracy worse. They do this with a "skip connection" that routes the input around a stack of layers.

Key components of a residual block:

1. **Skip Connection (Identity Path)**:
   - Passes the input $x$ directly to the output (or a 1×1 convolution of $x$ when the number of channels or the spatial resolution changes)
   - Preserves original features
   - Helps with gradient flow during backpropagation

2. **Residual Path**:
   - Processes the input through a stack of layers (here: 3×3 conv → batch norm → ReLU → 3×3 conv → batch norm)
   - Learns residual function F(x)
   - Output = $\mathrm{ReLU}(F(x) + x)$: the skip connection is added before the final activation

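The key insight: instead of learning a complete transformation $H(x)$ directly, the block learns the residual $F(x) = H(x) - x$. If the desired mapping is (close to) the identity, the block only has to drive $F(x)$ toward zero, for example by shrinking its weights. That is much easier than fitting an identity function with a stack of nonlinear layers, and it is what allows very deep networks to be trained successfully.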

In [None]:
class ResidualBlock(nn.Module):
    """ResNet-style residual block computing ``ReLU(F(X) + shortcut(X))``.

    The residual branch ``F`` is
    conv3x3 -> BatchNorm -> ReLU -> conv3x3 -> BatchNorm.
    The shortcut is the identity, or a 1x1 convolution when ``use_1x1conv``
    is True.

    Args:
        input_channels: Number of channels of the input tensor.
        num_channels: Number of output channels of the block.
        use_1x1conv: If True, project the shortcut with a strided 1x1
            convolution so its channel count and spatial size match the
            residual branch. Needed when ``input_channels != num_channels``
            or ``strides > 1``; otherwise the shapes generally will not match.
        strides: Stride of the first 3x3 convolution and of the 1x1 shortcut.
    """

    def __init__(self, input_channels, num_channels, use_1x1conv=False, strides=1):
        super().__init__()
        # First convolution; `strides` > 1 downsamples the spatial dimensions.
        self.conv1 = nn.Conv2d(input_channels, num_channels, kernel_size=3,
                               padding=1, stride=strides)
        self.bn1 = nn.BatchNorm2d(num_channels)

        # Second convolution keeps channels and spatial size unchanged.
        self.conv2 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(num_channels)

        # Optional 1x1 projection so the shortcut matches the residual branch.
        if use_1x1conv:
            self.conv3 = nn.Conv2d(input_channels, num_channels,
                                   kernel_size=1, stride=strides)
        else:
            self.conv3 = None

    def forward(self, X):
        """Apply the residual block.

        Args:
            X: Input batch in NCHW layout (BatchNorm2d requires 4-D input)
                with ``input_channels`` channels.

        Returns:
            Tensor of the same shape as the residual-branch output:
            ``num_channels`` channels, spatial dims reduced by ``strides``.

        Raises:
            ValueError: If the shortcut and the residual branch have different
                shapes, e.g. channels or stride change without ``use_1x1conv``.
        """
        # Residual branch F(X).
        Y = F.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))

        # Shortcut: identity, or the 1x1 projection when one was configured.
        shortcut = X if self.conv3 is None else self.conv3(X)

        # Without this check a size-1 channel dim in the shortcut would be
        # silently broadcast by the addition instead of raising.
        if shortcut.shape != Y.shape:
            raise ValueError(
                f"Residual branch output {tuple(Y.shape)} does not match shortcut "
                f"{tuple(shortcut.shape)}; pass use_1x1conv=True when changing "
                "channels or stride."
            )

        # Out-of-place add (rather than `Y += shortcut`) so no autograd-tracked
        # tensor is mutated in place.
        return F.relu(Y + shortcut)  # final activation after the addition

In [None]:
# Sanity-check output shapes on a sample batch. Shapes are deterministic; the
# seed only makes the random input and weights reproducible on Restart & Run All.
torch.manual_seed(0)
X = torch.randn(1, 3, 224, 224)  # (batch, channels, height, width)

# Same channels, stride 1: identity shortcut, so output shape == input shape.
same_shape_blk = ResidualBlock(3, 3)
Y1 = same_shape_blk(X)
assert Y1.shape == (1, 3, 224, 224)
print('Output shape with regular residual block:', Y1.shape)

# 3 -> 6 channels with stride 2: the 1x1-conv shortcut matches the downsampled
# branch, so the spatial dims go 224 -> 112.
downsample_blk = ResidualBlock(3, 6, use_1x1conv=True, strides=2)
Y2 = downsample_blk(X)
assert Y2.shape == (1, 6, 112, 112)
print('Output shape with strided residual block:', Y2.shape)