# 8.6. Residual Networks (ResNet) and ResNeXt

In [9]:
import torch
from torch import nn
from torch.nn import functional as F


## 8.6.2. Residual Blocks

https://d2l.ai/_images/residual-block.svg

In [2]:
class Residual(nn.Module):  #@save
    """The Residual block of ResNet models."""
    def __init__(self, num_channels, use_1x1conv=False, strides=1):
        super().__init__()
        self.conv1 = nn.LazyConv2d(num_channels, kernel_size=3, padding=1, stride=strides)
        self.conv2 = nn.LazyConv2d(num_channels, kernel_size=3, padding=1)
        if use_1x1conv:
            self.conv3 = nn.LazyConv2d(num_channels, kernel_size=1, stride=strides)
        else:
            self.conv3 = None
        self.bn1 = nn.LazyBatchNorm2d()
        self.bn2 = nn.LazyBatchNorm2d()

    def forward(self, X):
        Y = F.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))
        if self.conv3:
            X = self.conv3(X)
        Y += X
        return F.relu(Y)

In [12]:
blk = Residual(3)
blk

Residual(
  (conv1): LazyConv2d(0, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): LazyConv2d(0, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn1): LazyBatchNorm2d(0, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn2): LazyBatchNorm2d(0, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)

In [None]:
X = torch.rand(4, 3, 6, 6)
blk(X).shape

In [7]:
print(blk(X))

tensor([[[[0.4872, 1.0892, 0.5275],
          [0.4104, 0.1215, 0.0000],
          [0.0000, 2.4253, 0.0000]],

         [[0.0000, 1.9172, 0.0000],
          [0.0000, 0.7627, 0.0000],
          [1.1539, 0.5950, 0.0000]],

         [[1.0890, 0.0000, 0.0838],
          [0.0000, 0.0000, 0.1193],
          [0.0000, 0.0000, 0.0000]],

         [[0.0000, 0.0000, 0.6778],
          [0.0000, 0.5822, 0.1278],
          [0.0000, 0.9163, 1.1069]],

         [[0.5836, 0.1874, 1.8602],
          [1.2932, 1.7341, 2.3614],
          [0.0000, 0.8420, 1.1317]],

         [[0.0000, 2.2291, 0.0000],
          [0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.2304]]],


        [[[0.1678, 1.4587, 0.4502],
          [0.1689, 1.0954, 0.0000],
          [0.0000, 0.0000, 1.5281]],

         [[0.0000, 0.0000, 0.0616],
          [1.3951, 0.0000, 0.0000],
          [0.0806, 0.0000, 0.0000]],

         [[0.0000, 1.1774, 0.0000],
          [0.1377, 1.8422, 0.0000],
          [0.1729, 0.5947, 0.0355]],

        

In [4]:
blk = Residual(6, use_1x1conv=True, strides=2)
blk(X).shape

torch.Size([4, 6, 3, 3])