In [1]:
import torch
from torch import nn
from torch.nn import functional as F

In [4]:
class ResidualBlock(nn.Module):
    """Bottleneck residual block with an identity (non-projected) shortcut.

    Main branch: 1x1 conv -> BN -> ReLU -> 3x3 conv -> BN -> ReLU -> 1x1 conv
    -> BN. Its output is added to the unchanged input and passed through a
    final ReLU. Every conv has stride 1 and the 3x3 conv uses padding 1, so the
    spatial size is preserved.

    Args:
        in_channels: channels of the incoming feature map. Must equal
            ``filters[2]``, because the input is added to the branch output
            without any projection.
        filters: ``(f1, f2, f3)`` output channels of the three convolutions.

    Raises:
        ValueError: if ``in_channels != filters[2]``.
    """

    def __init__(self, in_channels, filters):
        super(ResidualBlock, self).__init__()
        f1, f2, f3 = filters
        if in_channels != f3:
            # Catch the mismatch at construction time instead of as an opaque
            # shape error at the residual addition in forward().
            raise ValueError(
                "ResidualBlock requires in_channels == filters[2] for its identity "
                f"shortcut; got in_channels={in_channels}, filters[2]={f3}"
            )

        # 1x1 conv: reduce channels.
        self.conv1 = nn.Conv2d(in_channels, f1, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn1 = nn.BatchNorm2d(f1)

        # 3x3 conv: spatial mixing; padding=1 keeps height/width.
        self.conv2 = nn.Conv2d(f1, f2, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(f2)

        # 1x1 conv: expand back to f3 (== in_channels) channels.
        self.conv3 = nn.Conv2d(f2, f3, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn3 = nn.BatchNorm2d(f3)

    def forward(self, x):
        shortcut = x

        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        # No ReLU before the addition; it is applied to the sum below.
        x = self.bn3(self.conv3(x))

        # In-place add targets the fresh bn3 output, not the caller's input.
        x += shortcut
        return F.relu(x)

class ConvBlock(nn.Module):
    """Bottleneck residual block whose shortcut is a 1x1 conv + BN projection.

    Meant to open a stage: it can change the channel count
    (``in_channels`` -> ``filters[2]``) and, through ``stride``, the spatial
    resolution. The identity path is therefore projected so that it matches
    the main branch before the two are summed and passed through a ReLU.

    Args:
        in_channels: channels of the incoming feature map.
        filters: ``(f1, f2, f3)`` output channels of the 1x1, 3x3 and 1x1 convs.
        stride: stride of the first 1x1 conv and of the projection conv.
    """

    def __init__(self, in_channels, filters, stride=2):
        super().__init__()
        reduce_ch, mid_ch, out_ch = filters

        # Main branch. Downsampling (if any) happens in the first 1x1 conv.
        self.conv1 = nn.Conv2d(in_channels, reduce_ch, kernel_size=1, stride=stride, padding=0, bias=False)
        self.bn1 = nn.BatchNorm2d(reduce_ch)
        self.conv2 = nn.Conv2d(reduce_ch, mid_ch, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(mid_ch)
        self.conv3 = nn.Conv2d(mid_ch, out_ch, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn3 = nn.BatchNorm2d(out_ch)

        # Projection shortcut: same stride and output width as the main branch.
        projection = [
            nn.Conv2d(in_channels, out_ch, kernel_size=1, stride=stride, bias=False),
            nn.BatchNorm2d(out_ch),
        ]
        self.shortcut = nn.Sequential(*projection)

    def forward(self, x):
        identity = self.shortcut(x)

        out = self.conv1(x)
        out = F.relu(self.bn1(out))
        out = self.conv2(out)
        out = F.relu(self.bn2(out))
        out = self.bn3(self.conv3(out))

        out += identity
        return F.relu(out)


class BasicBlock(nn.Module):
    """Residual block of two 3x3 convolutions, as used by ResNet-18/34.

    Main branch: 3x3 conv (with ``stride``) -> BN -> ReLU -> 3x3 conv -> BN.
    The shortcut is the identity when the branch keeps the input shape, and a
    strided 1x1 conv + BN projection otherwise. The sum goes through a ReLU.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by both 3x3 convolutions.
        stride: stride of the first 3x3 conv and of the projection shortcut.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super(BasicBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)

        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Project the identity path only when the main branch changes its shape.
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.shortcut = nn.Identity()

    def forward(self, x):
        shortcut = self.shortcut(x)

        x = F.relu(self.bn1(self.conv1(x)))
        x = self.bn2(self.conv2(x))

        # In-place add targets the fresh bn2 output, not the caller's input.
        x += shortcut
        return F.relu(x)


class ResNet(nn.Module):
    """ResNet image classifier.

    Layout: 7x7/2 conv stem -> BN -> ReLU -> 3x3/2 max-pool -> four residual
    stages (``stage2``..``stage5``) -> global average pool -> linear layer.

    Args:
        layers: number of residual blocks in each of the four stages; each
            stage always builds at least one block.
        num_classes: number of output logits.
        block: ``"bottleneck"`` (default; 1x1-3x3-1x1 blocks with stage widths
            256/512/1024/2048, as in ResNet-50/101/152) or ``"basic"`` (two 3x3
            convs with stage widths 64/128/256/512, as in ResNet-18/34).
        in_channels: channels of the input image (default 3).

    Raises:
        ValueError: if ``block`` is not ``"bottleneck"`` or ``"basic"``.
    """

    def __init__(self, layers, num_classes, block="bottleneck", in_channels=3):
        super(ResNet, self).__init__()
        if block not in ("bottleneck", "basic"):
            raise ValueError(f"block must be 'bottleneck' or 'basic', got {block!r}")

        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        if block == "bottleneck":
            self.stage2 = self._make_layer(64, [64, 64, 256], layers[0], stride=1)
            self.stage3 = self._make_layer(256, [128, 128, 512], layers[1], stride=2)
            self.stage4 = self._make_layer(512, [256, 256, 1024], layers[2], stride=2)
            self.stage5 = self._make_layer(1024, [512, 512, 2048], layers[3], stride=2)
            out_features = 2048
        else:
            self.stage2 = self._make_basic_layer(64, 64, layers[0], stride=1)
            self.stage3 = self._make_basic_layer(64, 128, layers[1], stride=2)
            self.stage4 = self._make_basic_layer(128, 256, layers[2], stride=2)
            self.stage5 = self._make_basic_layer(256, 512, layers[3], stride=2)
            out_features = 512

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(out_features, num_classes)

    def _make_layer(self, in_channels, filters, blocks, stride):
        """Build a bottleneck stage: one projecting ConvBlock, then ``blocks - 1`` ResidualBlocks."""
        layers = [ConvBlock(in_channels, filters, stride)]
        for _ in range(1, blocks):
            layers.append(ResidualBlock(filters[2], filters))
        return nn.Sequential(*layers)

    def _make_basic_layer(self, in_channels, out_channels, blocks, stride):
        """Build a basic stage: ``blocks`` BasicBlocks; only the first one may downsample."""
        layers = [BasicBlock(in_channels, out_channels, stride)]
        for _ in range(1, blocks):
            layers.append(BasicBlock(out_channels, out_channels))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.maxpool(x)

        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        x = self.stage5(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x


# Blocks per stage. ResNet18/34 use BasicBlock ("basic"); the deeper variants
# use bottleneck blocks.
resnet_configs = {
    'ResNet18': [2, 2, 2, 2],
    'ResNet34': [3, 4, 6, 3],
    'ResNet50': [3, 4, 6, 3],
    'ResNet101': [3, 4, 23, 3],
    'ResNet152': [3, 8, 36, 3]
}


def ResNet18(num_classes=1000):
    """ResNet-18: basic (two 3x3 conv) blocks, [2, 2, 2, 2] per stage."""
    return ResNet(resnet_configs['ResNet18'], num_classes, block='basic')


def ResNet34(num_classes=1000):
    """ResNet-34: basic (two 3x3 conv) blocks, [3, 4, 6, 3] per stage."""
    return ResNet(resnet_configs['ResNet34'], num_classes, block='basic')


def ResNet50(num_classes=1000):
    """ResNet-50: bottleneck blocks, [3, 4, 6, 3] per stage."""
    return ResNet(resnet_configs['ResNet50'], num_classes)


def ResNet101(num_classes=1000):
    """ResNet-101: bottleneck blocks, [3, 4, 23, 3] per stage."""
    return ResNet(resnet_configs['ResNet101'], num_classes)


def ResNet152(num_classes=1000):
    """ResNet-152: bottleneck blocks, [3, 8, 36, 3] per stage."""
    return ResNet(resnet_configs['ResNet152'], num_classes)

def test_resnet():
    """Smoke-test ResNet101 on a random 4x3x64x64 batch.

    Checks that the logits have shape (4, 6), then prints the shape and the
    module tree.
    """
    model = ResNet101(num_classes=6)
    x = torch.randn(4, 3, 64, 64)
    out = model(x)
    # Fail loudly instead of relying on someone eyeballing the printed shape.
    assert out.shape == (4, 6), f"unexpected output shape {tuple(out.shape)}"
    print(out.shape)
    print(model)


# Only run the (expensive) smoke test when executed directly or in a notebook,
# not when this module is imported.
if __name__ == "__main__":
    test_resnet()

torch.Size([4, 6])
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (stage2): Sequential(
    (0): ConvBlock(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(256, 