In [1]:
from d2l import torch as d2l
import torch
import torch.nn as nn
import torch.nn.functional as F

In [35]:
class Residual(nn.Module):
    def __init__(self, num_inputs, num_outputs, use_1x1_conv=False, strides=1):
        super().__init__()
        
        self.conv1 = nn.Conv2d(num_inputs, num_outputs, kernel_size=3, padding=1, stride=strides)
        self.conv2 = nn.Conv2d(num_outputs, num_outputs, kernel_size=3, padding=1)
        self.batchnorm1 = nn.BatchNorm2d(num_outputs)
        self.batchnorm2 = nn.BatchNorm2d(num_outputs)
        
        if (use_1x1_conv):
            self.conv3 = nn.Conv2d(num_inputs, num_outputs, kernel_size=1, stride=strides)
        else:
            self.conv3 = None
        self.relu = nn.ReLU()
    def forward(self, X):
        y = self.batchnorm2(self.conv2(self.relu(self.batchnorm1(self.conv1(X)))))
        if self.conv3:
            X = self.conv3(X)
        return F.relu(y + X)

In [36]:
blk = Residual(3, 3, use_1x1_conv=True, strides=2)
x = torch.rand((4, 3, 10, 10))
blk(x).shape

torch.Size([4, 3, 5, 5])

In [34]:
b1 = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),
                   nn.BatchNorm2d(64), nn.ReLU(),
                   nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

In [40]:
def resnet_block(num_inputs, num_channels, num_residuals, first_block=False):
    blks = []
    for i in range(num_residuals):
        if (i == 0) and not first_block:
                blks.append(Residual(num_inputs, num_channels, use_1x1_conv=True))
        else:
            blks.append(Residual(num_channels, num_channels))
    return blks

In [41]:
b2 = nn.Sequential(*resnet_block(64, 64, 2, first_block=True))
b3 = nn.Sequential(*resnet_block(64, 128, 2))
b4 = nn.Sequential(*resnet_block(128, 256, 2))
b5 = nn.Sequential(*resnet_block(256, 512, 2))

In [45]:
net = nn.Sequential(
    b1, b2, b3, b4, b5, nn.AdaptiveAvgPool2d((1, 1)), nn.Flatten(), nn.Linear(512, 10)
)

In [46]:
X = torch.rand(1, 1, 224, 224)
for layer in net:
    X = layer(X)
    print(f'{layer.__class__.__name__}, {X.shape}')

Sequential, torch.Size([1, 64, 56, 56])
Sequential, torch.Size([1, 64, 56, 56])
Sequential, torch.Size([1, 128, 56, 56])
Sequential, torch.Size([1, 256, 56, 56])
Sequential, torch.Size([1, 512, 56, 56])
AdaptiveAvgPool2d, torch.Size([1, 512, 1, 1])
Flatten, torch.Size([1, 512])
Linear, torch.Size([1, 10])
