In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class ResUnit(nn.Module):
    def __init__(self, input_channel:int, output_channel:int, bot_layer: bool=False):
        super().__init__()
        self.bot_layer = bot_layer
        self.conv1 = nn.Sequential(
            nn.Conv2d(input_channel, output_channel, kernel_size=3, stride=2 if bot_layer else 1, padding=1),
            nn.BatchNorm2d(output_channel),
            nn.ReLU(inplace=True),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(output_channel, output_channel, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(output_channel),
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(input_channel, output_channel, kernel_size=1, stride=2, padding=0),
            nn.BatchNorm2d(output_channel),
        )

    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2(out)
        if self.bot_layer:
            x = self.conv3(x)

        out = F.relu(out + x)
        return out

In [3]:
class ResNet34(nn.Module):
    def __init__(self, input_channel=3, num_classes=10):
        super().__init__()
        self.input_block = nn.Sequential(
            nn.Conv2d(input_channel, 64, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )

        self.res_block = nn.ModuleList()
        blocks = [3*[64], 4*[128], 6*[256], 3*[512]]

        for index, block in enumerate(blocks):
            module = []
            if index == 0:
                for i in block:
                    module.append(ResUnit(i, i))
                module = nn.Sequential(*module)
            else:
                module.append(ResUnit(block[0]//2, block[0], bot_layer=True))
                for i in block[1:]:
                    module.append(ResUnit(i, i))
                module = nn.Sequential(*module)
            self.res_block.append(module)

        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d((2,1)),
            nn.Flatten(),
            nn.Linear(1024, num_classes)
        )

    def forward(self, x):
        x = self.input_block(x)
        for module in self.res_block:
            x = module(x)
        x = self.classifier(x)
        return x

In [4]:
x = torch.randn(1, 3, 224, 224)
model = ResNet34()
model(x)

tensor([[ 0.7360, -0.3812, -0.2799,  0.3944,  0.7062, -0.3704,  1.3767,  0.5394,
         -0.2246,  0.0983]], grad_fn=<AddmmBackward0>)