In [2]:
import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l

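Vanishing gradient problem: as neural networks got deeper and were trained on more complex data, researchers found that gradients could shrink toward zero as they propagated backward through many layers during training, so the earliest layers barely learned.
This is a consequence of direct mapping, where each layer must compute the complete transformation from its input to its output. By the chain rule, the gradient reaching an early layer is a product of one factor per later layer, and when those factors are typically smaller than one the product shrinks exponentially with depth. Intuitively, every layer has to make measurable progress toward the final output on its own.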
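The shrinking product of per-layer factors can be observed directly. The sketch below is an illustration under assumed settings (sigmoid activations, width 16, depth 20, PyTorch's default `nn.Linear` initialization); `PlainBlock`, `ResidualBlock` and `first_layer_grad_norm` are hypothetical names, not part of d2l. It measures the gradient norm reaching the first layer of a plain stack and, for contrast, of the same stack where each layer adds its output to its input (the residual idea described next).

```python
import torch
from torch import nn

torch.manual_seed(0)

class PlainBlock(nn.Module):
    """Hypothetical plain layer: it must produce the whole output H(x) itself."""
    def __init__(self, width):
        super().__init__()
        self.linear = nn.Linear(width, width)

    def forward(self, x):
        return torch.sigmoid(self.linear(x))

class ResidualBlock(PlainBlock):
    """Same layer, but its output is added to the input (a skip connection)."""
    def forward(self, x):
        return x + torch.sigmoid(self.linear(x))

def first_layer_grad_norm(block_cls, depth=20, width=16):
    # Stack `depth` identical-architecture blocks and backpropagate a simple loss
    net = nn.Sequential(*[block_cls(width) for _ in range(depth)])
    x = torch.randn(8, width)
    net(x).sum().backward()
    # Norm of the gradient w.r.t. the weights of the earliest layer
    return net[0].linear.weight.grad.norm().item()

plain = first_layer_grad_norm(PlainBlock)
residual = first_layer_grad_norm(ResidualBlock)
print(f"plain: {plain:.2e}, residual: {residual:.2e}")
```

With these settings the plain stack's first-layer gradient comes out many orders of magnitude smaller than the residual stack's; the exact numbers depend on the seed and the initialization.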

The solution came with residual mapping: instead of learning the full transformation from input to output, the network learns only the value that must be added to the input.
- A traditional (plain) network learns H(x) directly
- A residual network learns U(x) = H(x) - x, which is the residual

The convolutional layers compute U(x), and the block then adds x back through the skip connection to obtain H(x) = U(x) + x, which is passed on to the next layer.
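A minimal sketch of this decomposition, assuming a hypothetical `ResidualWrapper` helper (not a d2l or PyTorch API) that wraps any module whose output shape matches its input: the wrapped layers compute U(x), the wrapper adds x, and subtracting x from the result recovers U(x).

```python
import torch
from torch import nn

class ResidualWrapper(nn.Module):
    """Hypothetical helper: wraps a shape-preserving module U and returns U(x) + x."""
    def __init__(self, residual_fn):
        super().__init__()
        self.residual_fn = residual_fn  # learns U(x) = H(x) - x

    def forward(self, x):
        return self.residual_fn(x) + x  # H(x) = U(x) + x

torch.manual_seed(0)
u = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 4))
block = ResidualWrapper(u)
x = torch.randn(2, 4)
h = block(x)
# Subtracting the input recovers what the layers computed: the residual U(x)
print(torch.allclose(h - x, u(x), atol=1e-6))  # True
```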

# Residual Blocks

Residual blocks are fundamental building blocks of modern deep neural networks, particularly ResNet architectures. They address the degradation problem, in which adding more layers to an already deep plain network makes its accuracy worse (even on the training data), through a skip connection that adds a block's input directly to its output.

Key components of a residual block:

1. **Skip Connection (Identity Path)**:
   - Directly passes input x to the output
   - Preserves original features
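   - Helps gradient flow during backpropagation: because the output contains x itself, the gradient with respect to x always includes a direct identity term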

2. **Residual Path**:
   - Processes input through layers (typically conv + batch norm + ReLU)
   - Learns residual function F(x)
   - Output = F(x) + x (the skip connection adds the input back); ResNet then applies a final ReLU
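The gradient-flow point follows from the derivative of the block output: d(F(x) + x)/dx = F'(x) + 1. The scalar sketch below assumes a toy residual path F(x) = w * x with a deliberately tiny weight w = 0.001 (an illustrative choice, not a trained value) and compares it with a plain layer that outputs w * x alone.

```python
import torch

# Toy residual path F(x) = w * x with a tiny, illustrative weight
w = torch.tensor(1e-3)
x = torch.tensor(2.0, requires_grad=True)

plain_out = w * x              # plain layer: must produce H(x) directly
plain_grad, = torch.autograd.grad(plain_out, x)

residual_out = w * x + x       # residual block: F(x) + x
residual_grad, = torch.autograd.grad(residual_out, x)

print(plain_grad.item(), residual_grad.item())  # ~0.001 vs ~1.001
```

Even when the residual path contributes almost nothing, the skip connection keeps the local gradient near 1 instead of near 0.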

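The key insight: instead of learning a complete transformation H(x), the network learns the residual F(x) = H(x) - x (the same quantity written U(x) earlier). If the identity mapping is close to optimal for a block, its layers only need to push F(x) toward zero, which is easier than fitting an identity through a stack of nonlinear layers. This makes it practical to train very deep networks.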
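To see why the identity is easy to represent, the sketch below assumes the residual path is a single `nn.Linear` layer whose weight and bias are all zero (a hypothetical setup chosen for illustration). The residual block then returns its input unchanged, while a plain layer with the same weights maps everything to zero.

```python
import torch
from torch import nn

torch.manual_seed(0)
x = torch.randn(3, 8)

# Residual path F(x) with all weights and biases set to zero (illustrative choice)
f = nn.Linear(8, 8)
nn.init.zeros_(f.weight)
nn.init.zeros_(f.bias)

residual_out = f(x) + x   # F(x) + x reduces to the identity when F(x) = 0
plain_out = f(x)          # a plain layer with the same weights outputs zeros

print(torch.equal(residual_out, x))            # True
print(torch.count_nonzero(plain_out).item())   # 0
```

The notebook's `Residual` class applies a ReLU after the addition, so with a zero residual path it returns ReLU(X) rather than X exactly. Relatedly, torchvision's ResNet constructors accept a `zero_init_residual` flag that zero-initializes the last batch-norm scale in each residual branch, so each block starts out behaving like an identity.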

In [3]:
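class Residual(nn.Module):  #@save
    """The Residual block of ResNet models."""
    def __init__(self, num_channels, use_1x1conv=False, strides=1):
        super().__init__()
        # Residual path: two 3x3 convolutions. padding=1 preserves height/width
        # at stride 1; strides > 1 on conv1 downsamples
        self.conv1 = nn.LazyConv2d(num_channels, kernel_size=3, padding=1,
                                   stride=strides)
        self.conv2 = nn.LazyConv2d(num_channels, kernel_size=3, padding=1)
        # Optional 1x1 conv on the skip path so X matches Y's channels and size
        if use_1x1conv:
            self.conv3 = nn.LazyConv2d(num_channels, kernel_size=1,
                                       stride=strides)
        else:
            self.conv3 = None
        self.bn1 = nn.LazyBatchNorm2d()
        self.bn2 = nn.LazyBatchNorm2d()

    def forward(self, X):
        # Residual path F(X): conv -> BN -> ReLU -> conv -> BN
        Y = F.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))
        # Skip path: identity, or a 1x1 projection when shapes differ
        if self.conv3 is not None:
            X = self.conv3(X)
        Y += X  # F(X) + X
        return F.relu(Y)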

In [4]:
blk = Residual(3)
X = torch.randn(4, 3, 6, 6)
blk(X).shape

torch.Size([4, 3, 6, 6])

In [5]:
blk = Residual(6, use_1x1conv=True, strides=2)
blk(X).shape

torch.Size([4, 6, 3, 3])
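The 1x1 convolution matters whenever the residual path changes the tensor's shape. The sketch below uses explicit `nn.Conv2d` layers with an assumed input of 3 channels (instead of the notebook's lazy modules, so it runs on its own) to reproduce the shapes from the last cell: with 6 output channels and stride 2, the residual path produces a (4, 6, 3, 3) tensor that cannot be added to the (4, 3, 6, 6) input, while a 1x1 convolution with the same stride projects the input to a matching shape.

```python
import torch
from torch import nn

X = torch.randn(4, 3, 6, 6)
# Residual path with 6 output channels and stride 2 (as in Residual(6, strides=2))
Y = nn.Conv2d(3, 6, kernel_size=3, padding=1, stride=2)(X)
print(Y.shape)  # torch.Size([4, 6, 3, 3])

mismatch = False
try:
    Y + X  # identity skip: (4, 6, 3, 3) and (4, 3, 6, 6) do not broadcast
except RuntimeError:
    mismatch = True
print("identity skip fails:", mismatch)  # identity skip fails: True

# 1x1 projection on the skip path, as conv3 does when use_1x1conv=True
proj = nn.Conv2d(3, 6, kernel_size=1, stride=2)
out = Y + proj(X)
print(out.shape)  # torch.Size([4, 6, 3, 3])
```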