In [70]:
import torch
import torch.nn as nn

from torchsummary import summary

In [73]:
class ResidualBlock(nn.Module):
    """Basic two-convolution residual block.

    Main path:  conv3x3(stride) -> BN -> ReLU -> conv3x3 -> BN
    Shortcut:   identity, or conv1x1(stride) -> BN when the main path changes
                the spatial size or the channel count.
    Output:     ReLU(main + shortcut)

    Args:
        in_c: number of input channels.
        out_c: number of output channels.
        is_downsample: if True, the first 3x3 conv and the shortcut projection
            use stride 2, halving H and W (odd sizes round up).
    """

    def __init__(self, in_c, out_c, is_downsample=False):
        super().__init__()

        stride = 2 if is_downsample else 1

        self.conv1 = nn.Sequential(
            nn.Conv2d(in_c, out_c, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_c),
            nn.ReLU(),
        )

        # The identity can only be added to the main path when shapes match.
        # Project it with a strided 1x1 conv whenever the block shrinks the
        # spatial size (stride 2) or changes the channel count.
        if is_downsample or in_c != out_c:
            self.shortcut_proj = nn.Sequential(
                nn.Conv2d(in_c, out_c, kernel_size=1, stride=stride, padding=0, bias=False),
                nn.BatchNorm2d(out_c),
            )
        else:
            self.shortcut_proj = None

        # No ReLU here: the activation is applied after the residual addition.
        self.conv2 = nn.Sequential(
            nn.Conv2d(out_c, out_c, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_c),
        )

        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply the block.

        Args:
            x: feature map of shape (N, in_c, H, W); BatchNorm2d requires 4-D input.

        Returns:
            Tensor of shape (N, out_c, H', W'), where H' = ceil(H / 2) when
            downsampling and H' = H otherwise (likewise for W).
        """
        identity = x

        out = self.conv1(x)
        out = self.conv2(out)

        # Explicit None check instead of relying on nn.Module truthiness.
        if self.shortcut_proj is not None:
            identity = self.shortcut_proj(identity)

        out += identity
        return self.relu(out)

In [56]:
# inp = torch.randn(1, 64, 56, 56)
# print(inp.shape)

# res_block = ResidualBlock(in_c=64, out_c=128, is_downsample=True)
# out = res_block(inp)
# print(out.shape)

In [74]:
class resnet34(nn.Module):
    """34-layer residual network built from ResidualBlock.

    Layout (blocks per stage / channels):
        conv1    : 7x7 conv, stride 2, 64 channels
        conv2_x  : 3x3 max-pool (stride 2) + 3 blocks,  64 channels
        conv3_x  : 4 blocks, 128 channels (conv3_1 downsamples by 2)
        conv4_x  : 6 blocks, 256 channels (conv4_1 downsamples by 2)
        conv5_x  : 3 blocks, 512 channels (conv5_1 downsamples by 2)
        head     : global average pool -> Linear(512, out_classes)

    Args:
        in_c: number of input image channels.
        out_classes: number of logits produced by the classifier.

    Input:  (N, in_c, H, W); a 224x224 input yields a 7x7 conv5_x feature map.
    Output: (N, out_classes) logits.
    """

    def __init__(self, in_c=3, out_classes=1000):
        super().__init__()

        # conv1
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_c, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )

        # conv2_x
        self.conv2_maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.conv2_1 = ResidualBlock(in_c=64, out_c=64)
        self.conv2_2 = ResidualBlock(in_c=64, out_c=64)
        self.conv2_3 = ResidualBlock(in_c=64, out_c=64)

        # conv3_x
        self.conv3_1 = ResidualBlock(in_c=64, out_c=128, is_downsample=True)
        self.conv3_2 = ResidualBlock(in_c=128, out_c=128)
        self.conv3_3 = ResidualBlock(in_c=128, out_c=128)
        self.conv3_4 = ResidualBlock(in_c=128, out_c=128)

        # conv4_x
        self.conv4_1 = ResidualBlock(in_c=128, out_c=256, is_downsample=True)
        self.conv4_2 = ResidualBlock(in_c=256, out_c=256)
        self.conv4_3 = ResidualBlock(in_c=256, out_c=256)
        self.conv4_4 = ResidualBlock(in_c=256, out_c=256)
        self.conv4_5 = ResidualBlock(in_c=256, out_c=256)
        self.conv4_6 = ResidualBlock(in_c=256, out_c=256)

        # conv5_x
        self.conv5_1 = ResidualBlock(in_c=256, out_c=512, is_downsample=True)
        self.conv5_2 = ResidualBlock(in_c=512, out_c=512)
        self.conv5_3 = ResidualBlock(in_c=512, out_c=512)

        # classifier
        # Adaptive pooling collapses any HxW feature map to 1x1. For 224x224
        # inputs (7x7 maps) this is identical to AvgPool2d(kernel_size=7).
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(512, out_classes, bias=True)

    def forward(self, x):
        """Map an (N, in_c, H, W) batch to (N, out_classes) logits."""
        # conv1
        x = self.conv1(x)

        # conv2_x
        x = self.conv2_maxpool(x)
        x = self.conv2_1(x)
        x = self.conv2_2(x)
        x = self.conv2_3(x)

        # conv3_x
        x = self.conv3_1(x)
        x = self.conv3_2(x)
        x = self.conv3_3(x)
        x = self.conv3_4(x)

        # conv4_x
        x = self.conv4_1(x)
        x = self.conv4_2(x)
        x = self.conv4_3(x)
        x = self.conv4_4(x)
        x = self.conv4_5(x)
        x = self.conv4_6(x)

        # conv5_x
        x = self.conv5_1(x)
        x = self.conv5_2(x)
        x = self.conv5_3(x)

        # (N, 512, 1, 1) -> (N, 512). Flattening from dim 1 keeps the batch
        # dimension intact instead of letting view(-1, 512) absorb leftovers.
        x = self.gap(x)
        x = torch.flatten(x, 1)

        return self.classifier(x)

In [75]:
# Smoke test on one random image; input layout is N, C, H, W.
torch.manual_seed(0)  # reproducible random input and weight initialisation
inp = torch.randn(1, 3, 224, 224)
print(inp.shape)

net = resnet34()
out = net(inp)

print(out.shape)
# Default out_classes is 1000, so one image must yield one row of 1000 logits.
assert out.shape == (1, 1000), f"unexpected output shape {tuple(out.shape)}"

torch.Size([1, 3, 224, 224])
torch.Size([1, 1000])


In [76]:
# Per-layer output shapes and parameter counts for a 3x224x224 input, run on CPU.
summary(model=net, input_size=(3, 224, 224), device="cpu")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0
            Conv2d-5           [-1, 64, 56, 56]          36,864
       BatchNorm2d-6           [-1, 64, 56, 56]             128
              ReLU-7           [-1, 64, 56, 56]               0
            Conv2d-8           [-1, 64, 56, 56]          36,864
       BatchNorm2d-9           [-1, 64, 56, 56]             128
             ReLU-10           [-1, 64, 56, 56]               0
    ResidualBlock-11           [-1, 64, 56, 56]               0
           Conv2d-12           [-1, 64, 56, 56]          36,864
      BatchNorm2d-13           [-1, 64, 56, 56]             128
             ReLU-14           [-1, 64,