In [None]:
from torch import nn

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(
            inplanes, planes,
            kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(
            planes, planes,
            kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or inplanes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    inplanes, self.expansion*planes,
                    kernel_size=1, stride=stride, bias=False
                ),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out += self.shortcut(x)
        out = self.relu(out)
        return out

In [3]:
class BottleneckBlock(nn.Module):
    """Three-convolution bottleneck residual block used by ResNet-50/101/152.

    Main path: 1x1 conv (reduce to ``planes``) -> BN -> ReLU ->
    3x3 conv (carries ``stride``) -> BN -> ReLU ->
    1x1 conv (expand to ``planes * expansion``) -> BN.
    The shortcut is the identity unless the stride is not 1 or the channel
    count changes, in which case a 1x1 conv + BN projection is used. The two
    paths are summed and passed through a final ReLU.

    Args:
        inplanes: channels of the incoming feature map.
        planes: bottleneck width; the block outputs ``planes * expansion`` channels.
        stride: stride of the 3x3 conv and of the projection shortcut.
    """

    # Output channels = planes * expansion; ResNet reads this to size later stages.
    expansion = 4

    def __init__(self, inplanes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        # Keep this construction order (note relu comes after bn3): it fixes the
        # parameter-initialisation RNG order and the module/state_dict ordering.
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)

        needs_projection = stride != 1 or inplanes != out_planes
        if needs_projection:
            # 1x1 conv matches the main path's channel count and spatial size.
            self.shortcut = nn.Sequential(
                nn.Conv2d(inplanes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            # An empty Sequential returns its input unchanged (identity shortcut).
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return ReLU(main_path(x) + shortcut(x))."""
        branch = self.relu(self.bn1(self.conv1(x)))
        branch = self.relu(self.bn2(self.conv2(branch)))
        branch = self.bn3(self.conv3(branch))
        branch += self.shortcut(x)
        return self.relu(branch)

In [4]:
import torch


class ResNet(nn.Module):
    """ResNet image classifier: stem -> four residual stages -> global avg-pool -> linear.

    Args:
        block: residual block class (e.g. ``BasicBlock`` or ``BottleneckBlock``). It must
            define a class attribute ``expansion`` and be constructible as
            ``block(inplanes, planes, stride)``.
        layers: four positive ints, the number of blocks in each stage
            (e.g. ``[2, 2, 2, 2]`` for ResNet-18, ``[3, 4, 6, 3]`` for ResNet-34/50).
        num_classes: output size of the final linear layer.
        in_channels: channels of the input image; 3 (RGB) by default.

    Raises:
        ValueError: if ``layers`` does not have exactly four entries or any entry is < 1.
    """

    def __init__(self, block, layers, num_classes=1000, in_channels=3):
        super().__init__()
        layers = list(layers)
        if len(layers) != 4:
            raise ValueError(f"layers must give 4 stage depths, got {len(layers)}: {layers}")
        if any(depth < 1 for depth in layers):
            # _make_layer always builds the first block, so a depth < 1 cannot be honoured.
            raise ValueError(f"every stage needs at least one block, got {layers}")

        # Channel count entering the next stage; _make_layer advances it.
        self.inplanes = 64
        self.stem = nn.Sequential(
            nn.Conv2d(in_channels, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(self.inplanes),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )
        # Stage 1 keeps the stem's resolution; stages 2-4 downsample with stride 2.
        self.stage1 = self._make_layer(block, 64, layers[0], stride=1)
        self.stage2 = self._make_layer(block, 128, layers[1], stride=2)
        self.stage3 = self._make_layer(block, 256, layers[2], stride=2)
        self.stage4 = self._make_layer(block, 512, layers[3], stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage of ``num_blocks`` blocks; only the first block uses ``stride``.

        Side effect: sets ``self.inplanes`` to ``planes * block.expansion`` (the stage's
        output channels) so the following blocks and the next stage are wired correctly.
        """
        blocks = [block(self.inplanes, planes, stride)]
        self.inplanes = planes * block.expansion
        blocks.extend(block(self.inplanes, planes, 1) for _ in range(num_blocks - 1))
        return nn.Sequential(*blocks)

    def forward(self, x):
        """Map a batch of images to raw class scores.

        Args:
            x: image batch, presumably (N, in_channels, H, W) as nn.Conv2d expects.

        Returns:
            Tensor of shape (N, num_classes): the fc output (logits, no softmax applied).
        """
        out = self.stem(x)
        out = self.stage1(out)
        out = self.stage2(out)
        out = self.stage3(out)
        out = self.stage4(out)
        out = self.avgpool(out)
        out = torch.flatten(out, 1)
        out = self.fc(out)
        return out

In [None]:
from torchvision import models
from torchinfo import summary

# Build the five standard depths; resnet34 is the one checked against torchvision below.
resnet18 = ResNet(BasicBlock, [2, 2, 2, 2], 1000)
resnet34 = ResNet(BasicBlock, [3, 4, 6, 3], 1000)
resnet50 = ResNet(BottleneckBlock, [3, 4, 6, 3], 1000)
resnet101 = ResNet(BottleneckBlock, [3, 4, 23, 3], 1000)
resnet152 = ResNet(BottleneckBlock, [3, 8, 36, 3], 1000)
# Reference implementation. NOTE(review): pretrained weights are not needed for a
# param/shape comparison; weights=None would skip the download if nothing later uses them.
torch_model = models.resnet34(weights="ResNet34_Weights.IMAGENET1K_V1")

INPUT_SIZE = (1, 3, 224, 224)  # (batch, channels, height, width)
resnet34_info = summary(resnet34, INPUT_SIZE, verbose=0)
torch_model_info = summary(torch_model, INPUT_SIZE, verbose=0)

print(resnet34_info.total_params)
print(torch_model_info.total_params)

# Fail loudly instead of relying on eyeballing the two printed numbers. Equal totals
# could still hide a different layout, so also compare every parameter shape in order.
assert resnet34_info.total_params == torch_model_info.total_params, (
    "total parameter count differs from torchvision resnet34"
)
own_shapes = [tuple(p.shape) for p in resnet34.parameters()]
ref_shapes = [tuple(p.shape) for p in torch_model.parameters()]
assert own_shapes == ref_shapes, "parameter shapes differ from torchvision resnet34"

print(resnet34_info)


21797672
21797672
Layer (type:depth-idx)                   Output Shape              Param #
ResNet                                   [1, 1000]                 --
├─Sequential: 1-1                        [1, 64, 56, 56]           --
│    └─Conv2d: 2-1                       [1, 64, 112, 112]         9,408
│    └─BatchNorm2d: 2-2                  [1, 64, 112, 112]         128
│    └─ReLU: 2-3                         [1, 64, 112, 112]         --
│    └─MaxPool2d: 2-4                    [1, 64, 56, 56]           --
├─Sequential: 1-2                        [1, 64, 56, 56]           --
│    └─BasicBlock: 2-5                   [1, 64, 56, 56]           --
│    │    └─Conv2d: 3-1                  [1, 64, 56, 56]           36,864
│    │    └─BatchNorm2d: 3-2             [1, 64, 56, 56]           128
│    │    └─ReLU: 3-3                    [1, 64, 56, 56]           --
│    │    └─Conv2d: 3-4                  [1, 64, 56, 56]           36,864
│    │    └─BatchNorm2d: 3-5             [1, 64, 56, 5

In [None]:
# output shape - [1, 64, 56, 56] - (batch size, channels, height, width)

# (1, 3, 224, 224) 입력 이미지 정보 

# ==========================================================================================
# Layer (type:depth-idx)                   Output Shape              Param #
# ==========================================================================================
# ResNet                                   [1, 1000]                 --
#                                           1 - batch size : 한 번에 처리하는 이미지 개수
#                                           1000 - 분류 클래스 개수 : 1000개 클래스 각각의 점수(logit) 출력 (확률은 softmax 적용 후 얻음)
# Sequential: 1-1 (입력부, stem):
#   Conv2d (7x7, stride 2): 이미지를 스캔하여 64개의 특징(Channel)을 추출하고, 크기를 224 x 224 -> 112 x 112 로 줄입니다.
#   MaxPool2d (3x3, stride 2): 이미지 크기를 112 x 112 -> 56 x 56 으로 줄여 연산 효율을 높입니다.

# ├─Sequential: 1-1                        [1, 64, 56, 56]           --
# │    └─Conv2d: 2-1                       [1, 64, 112, 112]         9,408
# │    └─BatchNorm2d: 2-2                  [1, 64, 112, 112]         128
# │    └─ReLU: 2-3                         [1, 64, 112, 112]         --
# │    └─MaxPool2d: 2-4                    [1, 64, 56, 56]           --



# ├─Sequential: 1-2                        [1, 64, 56, 56]           --
# │    └─BasicBlock: 2-5                   [1, 64, 56, 56]           --
# │    │    └─Conv2d: 3-1                  [1, 64, 56, 56]           36,864
# │    │    └─BatchNorm2d: 3-2             [1, 64, 56, 56]           128
# │    │    └─ReLU: 3-3                    [1, 64, 56, 56]           --
# │    │    └─Conv2d: 3-4                  [1, 64, 56, 56]           36,864
# │    │    └─BatchNorm2d: 3-5             [1, 64, 56, 56]           128
# │    │    └─Sequential: 3-6              [1, 64, 56, 56]           --
# │    │    └─ReLU: 3-7                    [1, 64, 56, 56]           --
# │    └─BasicBlock: 2-6                   [1, 64, 56, 56]           --
# │    │    └─Conv2d: 3-8                  [1, 64, 56, 56]           36,864
# │    │    └─BatchNorm2d: 3-9             [1, 64, 56, 56]           128
# │    │    └─ReLU: 3-10                   [1, 64, 56, 56]           --
# │    │    └─Conv2d: 3-11                 [1, 64, 56, 56]           36,864
# │    │    └─BatchNorm2d: 3-12            [1, 64, 56, 56]           128
# │    │    └─Sequential: 3-13             [1, 64, 56, 56]           --
# │    │    └─ReLU: 3-14                   [1, 64, 56, 56]           --
# │    └─BasicBlock: 2-7                   [1, 64, 56, 56]           --
# │    │    └─Conv2d: 3-15                 [1, 64, 56, 56]           36,864
# │    │    └─BatchNorm2d: 3-16            [1, 64, 56, 56]           128
# │    │    └─ReLU: 3-17                   [1, 64, 56, 56]           --
# │    │    └─Conv2d: 3-18                 [1, 64, 56, 56]           36,864
# │    │    └─BatchNorm2d: 3-19            [1, 64, 56, 56]           128
# │    │    └─Sequential: 3-20             [1, 64, 56, 56]           --
# │    │    └─ReLU: 3-21                   [1, 64, 56, 56]           --
# ├─Sequential: 1-3                        [1, 128, 28, 28]          --
# │    └─BasicBlock: 2-8                   [1, 128, 28, 28]          --
# │    │    └─Conv2d: 3-22                 [1, 128, 28, 28]          73,728
# │    │    └─BatchNorm2d: 3-23            [1, 128, 28, 28]          256
# │    │    └─ReLU: 3-24                   [1, 128, 28, 28]          --
# │    │    └─Conv2d: 3-25                 [1, 128, 28, 28]          147,456
# │    │    └─BatchNorm2d: 3-26            [1, 128, 28, 28]          256
# │    │    └─Sequential: 3-27             [1, 128, 28, 28]          8,448
# │    │    └─ReLU: 3-28                   [1, 128, 28, 28]          --
# │    └─BasicBlock: 2-9                   [1, 128, 28, 28]          --
# │    │    └─Conv2d: 3-29                 [1, 128, 28, 28]          147,456
# │    │    └─BatchNorm2d: 3-30            [1, 128, 28, 28]          256
# │    │    └─ReLU: 3-31                   [1, 128, 28, 28]          --
# │    │    └─Conv2d: 3-32                 [1, 128, 28, 28]          147,456
# │    │    └─BatchNorm2d: 3-33            [1, 128, 28, 28]          256
# │    │    └─Sequential: 3-34             [1, 128, 28, 28]          --
# │    │    └─ReLU: 3-35                   [1, 128, 28, 28]          --
# │    └─BasicBlock: 2-10                  [1, 128, 28, 28]          --
# │    │    └─Conv2d: 3-36                 [1, 128, 28, 28]          147,456
# │    │    └─BatchNorm2d: 3-37            [1, 128, 28, 28]          256
# │    │    └─ReLU: 3-38                   [1, 128, 28, 28]          --
# │    │    └─Conv2d: 3-39                 [1, 128, 28, 28]          147,456
# │    │    └─BatchNorm2d: 3-40            [1, 128, 28, 28]          256
# │    │    └─Sequential: 3-41             [1, 128, 28, 28]          --
# │    │    └─ReLU: 3-42                   [1, 128, 28, 28]          --
# │    └─BasicBlock: 2-11                  [1, 128, 28, 28]          --
# │    │    └─Conv2d: 3-43                 [1, 128, 28, 28]          147,456
# │    │    └─BatchNorm2d: 3-44            [1, 128, 28, 28]          256
# │    │    └─ReLU: 3-45                   [1, 128, 28, 28]          --
# │    │    └─Conv2d: 3-46                 [1, 128, 28, 28]          147,456
# │    │    └─BatchNorm2d: 3-47            [1, 128, 28, 28]          256
# │    │    └─Sequential: 3-48             [1, 128, 28, 28]          --
# │    │    └─ReLU: 3-49                   [1, 128, 28, 28]          --

# ├─Sequential: 1-4                        [1, 256, 14, 14]          --
# │    └─BasicBlock: 2-12                  [1, 256, 14, 14]          --
# │    │    └─Conv2d: 3-50                 [1, 256, 14, 14]          294,912
# │    │    └─BatchNorm2d: 3-51            [1, 256, 14, 14]          512
# │    │    └─ReLU: 3-52                   [1, 256, 14, 14]          --
# │    │    └─Conv2d: 3-53                 [1, 256, 14, 14]          589,824
# │    │    └─BatchNorm2d: 3-54            [1, 256, 14, 14]          512
# │    │    └─Sequential: 3-55             [1, 256, 14, 14]          33,280
# │    │    └─ReLU: 3-56                   [1, 256, 14, 14]          --
# │    └─BasicBlock: 2-13                  [1, 256, 14, 14]          --
# │    │    └─Conv2d: 3-57                 [1, 256, 14, 14]          589,824
# │    │    └─BatchNorm2d: 3-58            [1, 256, 14, 14]          512
# │    │    └─ReLU: 3-59                   [1, 256, 14, 14]          --
# │    │    └─Conv2d: 3-60                 [1, 256, 14, 14]          589,824
# │    │    └─BatchNorm2d: 3-61            [1, 256, 14, 14]          512
# │    │    └─Sequential: 3-62             [1, 256, 14, 14]          --
# │    │    └─ReLU: 3-63                   [1, 256, 14, 14]          --
# │    └─BasicBlock: 2-14                  [1, 256, 14, 14]          --
# │    │    └─Conv2d: 3-64                 [1, 256, 14, 14]          589,824
# │    │    └─BatchNorm2d: 3-65            [1, 256, 14, 14]          512
# │    │    └─ReLU: 3-66                   [1, 256, 14, 14]          --
# │    │    └─Conv2d: 3-67                 [1, 256, 14, 14]          589,824
# │    │    └─BatchNorm2d: 3-68            [1, 256, 14, 14]          512
# │    │    └─Sequential: 3-69             [1, 256, 14, 14]          --
# │    │    └─ReLU: 3-70                   [1, 256, 14, 14]          --
# │    └─BasicBlock: 2-15                  [1, 256, 14, 14]          --
# │    │    └─Conv2d: 3-71                 [1, 256, 14, 14]          589,824
# │    │    └─BatchNorm2d: 3-72            [1, 256, 14, 14]          512
# │    │    └─ReLU: 3-73                   [1, 256, 14, 14]          --
# │    │    └─Conv2d: 3-74                 [1, 256, 14, 14]          589,824
# │    │    └─BatchNorm2d: 3-75            [1, 256, 14, 14]          512
# │    │    └─Sequential: 3-76             [1, 256, 14, 14]          --
# │    │    └─ReLU: 3-77                   [1, 256, 14, 14]          --
# │    └─BasicBlock: 2-16                  [1, 256, 14, 14]          --
# │    │    └─Conv2d: 3-78                 [1, 256, 14, 14]          589,824
# │    │    └─BatchNorm2d: 3-79            [1, 256, 14, 14]          512
# │    │    └─ReLU: 3-80                   [1, 256, 14, 14]          --
# │    │    └─Conv2d: 3-81                 [1, 256, 14, 14]          589,824
# │    │    └─BatchNorm2d: 3-82            [1, 256, 14, 14]          512
# │    │    └─Sequential: 3-83             [1, 256, 14, 14]          --
# │    │    └─ReLU: 3-84                   [1, 256, 14, 14]          --
# │    └─BasicBlock: 2-17                  [1, 256, 14, 14]          --
# │    │    └─Conv2d: 3-85                 [1, 256, 14, 14]          589,824
# │    │    └─BatchNorm2d: 3-86            [1, 256, 14, 14]          512
# │    │    └─ReLU: 3-87                   [1, 256, 14, 14]          --
# │    │    └─Conv2d: 3-88                 [1, 256, 14, 14]          589,824
# │    │    └─BatchNorm2d: 3-89            [1, 256, 14, 14]          512
# │    │    └─Sequential: 3-90             [1, 256, 14, 14]          --
# │    │    └─ReLU: 3-91                   [1, 256, 14, 14]          --

# ├─Sequential: 1-5                        [1, 512, 7, 7]            --
# │    └─BasicBlock: 2-18                  [1, 512, 7, 7]            --
# │    │    └─Conv2d: 3-92                 [1, 512, 7, 7]            1,179,648
# │    │    └─BatchNorm2d: 3-93            [1, 512, 7, 7]            1,024
# │    │    └─ReLU: 3-94                   [1, 512, 7, 7]            --
# │    │    └─Conv2d: 3-95                 [1, 512, 7, 7]            2,359,296
# │    │    └─BatchNorm2d: 3-96            [1, 512, 7, 7]            1,024
# │    │    └─Sequential: 3-97             [1, 512, 7, 7]            132,096
# │    │    └─ReLU: 3-98                   [1, 512, 7, 7]            --
# │    └─BasicBlock: 2-19                  [1, 512, 7, 7]            --
# │    │    └─Conv2d: 3-99                 [1, 512, 7, 7]            2,359,296
# │    │    └─BatchNorm2d: 3-100           [1, 512, 7, 7]            1,024
# │    │    └─ReLU: 3-101                  [1, 512, 7, 7]            --
# │    │    └─Conv2d: 3-102                [1, 512, 7, 7]            2,359,296
# │    │    └─BatchNorm2d: 3-103           [1, 512, 7, 7]            1,024
# │    │    └─Sequential: 3-104            [1, 512, 7, 7]            --
# │    │    └─ReLU: 3-105                  [1, 512, 7, 7]            --
# │    └─BasicBlock: 2-20                  [1, 512, 7, 7]            --
# │    │    └─Conv2d: 3-106                [1, 512, 7, 7]            2,359,296
# │    │    └─BatchNorm2d: 3-107           [1, 512, 7, 7]            1,024
# │    │    └─ReLU: 3-108                  [1, 512, 7, 7]            --
# │    │    └─Conv2d: 3-109                [1, 512, 7, 7]            2,359,296
# │    │    └─BatchNorm2d: 3-110           [1, 512, 7, 7]            1,024
# │    │    └─Sequential: 3-111            [1, 512, 7, 7]            --
# │    │    └─ReLU: 3-112                  [1, 512, 7, 7]            --
# ├─AdaptiveAvgPool2d: 1-6                 [1, 512, 1, 1]            --
# ├─Linear: 1-7                            [1, 1000]                 513,000
# ==========================================================================================
# Total params: 21,797,672
# Trainable params: 21,797,672
# Non-trainable params: 0
# Total mult-adds (G): 3.66
# ==========================================================================================
# Input size (MB): 0.60
# Forward/backward pass size (MB): 59.82
# Params size (MB): 87.19
# Estimated Total Size (MB): 147.61
# ==========================================================================================