In [None]:
import torch
from torch import nn

In [None]:
class BasicBlock(nn.Module):
  """Residual block with two 3x3 convolutions (used by ResNet-18/34).

  residual branch : conv3x3(stride) -> BN -> ReLU -> conv3x3 -> BN
  shortcut        : ``x`` itself, or ``projection(x)`` when a projection is given
  output          : ReLU(residual + shortcut)

  Args:
    in_channels: channels of the incoming tensor.
    inner_channels: channels produced by the first convolution; the block
      outputs ``inner_channels * expansion`` channels.
    stride: stride of the first convolution.
    projection: optional module applied to ``x`` on the shortcut path so its
      shape matches the residual branch; ``None`` means an identity shortcut.
  """

  # Ratio of output channels to inner channels; a basic block keeps the width.
  expansion = 1

  def __init__(self, in_channels, inner_channels, stride = 1, projection=None):
    super().__init__()
    width_out = inner_channels * self.expansion
    layers = [
        nn.Conv2d(in_channels, inner_channels, 3, stride=stride, padding=1, bias=False),
        nn.BatchNorm2d(inner_channels),
        nn.ReLU(inplace=True),
        nn.Conv2d(inner_channels, width_out, 3, padding=1, bias=False),
        nn.BatchNorm2d(width_out),
    ]
    self.residual = nn.Sequential(*layers)
    self.projection = projection
    self.relu = nn.ReLU(inplace=True)

  def forward(self, x):
    branch = self.residual(x)
    identity = x if self.projection is None else self.projection(x)
    return self.relu(branch + identity)

class BottleneckBlock(nn.Module):
  """Bottleneck residual block (used by ResNet-50/101/152).

  residual branch : conv1x1 -> BN -> ReLU -> conv3x3(stride) -> BN -> ReLU
                    -> conv1x1 (widen to inner_channels * 4) -> BN
  shortcut        : ``x`` itself, or ``projection(x)`` when a projection is given
  output          : ReLU(residual + shortcut)

  Args:
    in_channels: channels of the incoming tensor.
    inner_channels: channels of the two inner convolutions; the block outputs
      ``inner_channels * expansion`` channels.
    stride: stride of the middle 3x3 convolution.
    projection: optional module applied to ``x`` on the shortcut path so its
      shape matches the residual branch; ``None`` means an identity shortcut.
  """

  # The last 1x1 conv widens the output to 4x the inner channel count.
  expansion = 4

  def __init__(self, in_channels, inner_channels, stride=1, projection=None):
    super().__init__()
    widened = inner_channels * self.expansion
    # Layers are created in the same order they run in the residual branch.
    reduce_conv = nn.Conv2d(in_channels, inner_channels, 1, bias=False)
    reduce_bn = nn.BatchNorm2d(inner_channels)
    reduce_act = nn.ReLU(inplace=True)
    spatial_conv = nn.Conv2d(inner_channels, inner_channels, 3, stride=stride, padding=1, bias=False)
    spatial_bn = nn.BatchNorm2d(inner_channels)
    spatial_act = nn.ReLU(inplace=True)
    expand_conv = nn.Conv2d(inner_channels, widened, 1, bias=False)
    expand_bn = nn.BatchNorm2d(widened)
    self.residual = nn.Sequential(reduce_conv, reduce_bn, reduce_act,
                                  spatial_conv, spatial_bn, spatial_act,
                                  expand_conv, expand_bn)
    self.projection = projection
    self.relu = nn.ReLU(inplace=True)

  def forward(self, x):
    branch = self.residual(x)
    identity = x if self.projection is None else self.projection(x)
    return self.relu(branch + identity)

class ResNet(nn.Module):
  """ResNet backbone plus a linear classifier, built from residual blocks.

  Layout: 7x7/2 conv -> BN -> ReLU -> 3x3/2 max-pool -> four stages of
  residual blocks (64, 128, 256, 512 inner channels; stages 2-4 use stride 2)
  -> adaptive average pool to 1x1 -> flatten -> fully connected layer.

  Args:
    block: residual block class (e.g. ``BasicBlock`` or ``BottleneckBlock``).
      It must have a class attribute ``expansion`` and accept
      ``(in_channels, inner_channels, stride, projection)``. When
      ``zero_init_residual`` is true it must also keep its residual branch in
      an indexable ``residual`` attribute whose last element has a ``weight``.
    num_block_list: number of blocks in each of the four stages (the first
      four entries are used).
    num_classes: number of classifier outputs.
    zero_init_residual: if true, set the weight of the last layer of every
      residual branch to zero (the final BatchNorm in the blocks of this
      file), so each residual branch initially outputs zeros and every block
      starts out passing only its shortcut path.
  """

  def __init__(self, block, num_block_list, num_classes = 1000, zero_init_residual = True):
    super().__init__()
    # Channel count entering the next stage; make_stage() reads and updates it.
    self.in_channels = 64
    self.conv1 = nn.Conv2d(3, 64, kernel_size = 7, stride = 2, padding = 3, bias=False)
    self.bn1 = nn.BatchNorm2d(64)
    self.relu = nn.ReLU(inplace=True)
    self.maxpool = nn.MaxPool2d(kernel_size = 3, stride = 2, padding = 1)
    self.stage1 = self.make_stage(block, 64, num_block_list[0], stride = 1)
    self.stage2 = self.make_stage(block, 128, num_block_list[1], stride = 2)
    self.stage3 = self.make_stage(block, 256, num_block_list[2], stride = 2)
    self.stage4 = self.make_stage(block, 512, num_block_list[3], stride = 2)
    self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
    self.fc = nn.Linear(512 * block.expansion, num_classes)

    # Kaiming-normal initialisation (fan_out, ReLU gain) for every convolution.
    for m in self.modules():
      if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")

    if zero_init_residual:
      for m in self.modules():
        # `m, block` (two arguments). Only residual blocks have a `.residual` branch.
        if isinstance(m, block):
          nn.init.constant_(m.residual[-1].weight, 0)

  def make_stage(self, block, inner_channels, num_blocks, stride=1):
    """Build one stage as an ``nn.Sequential`` of ``num_blocks`` residual blocks.

    Only the first block of a stage may change the spatial size (via
    ``stride``) or the channel count, so only it can get a projection shortcut
    (1x1 conv + BN, the "dashed" connection), which is applied to the input.
    The remaining blocks use the identity shortcut (the "solid" connection)
    and add ``x`` directly.

    Side effect: sets ``self.in_channels`` to ``inner_channels * block.expansion``
    so the next stage is built with the correct input width.
    """
    out_channels = inner_channels * block.expansion
    # A projection is required when the block downsamples OR changes the channel
    # count. Stage 1 has stride 1, so the channel check is what adds a projection
    # there for bottleneck blocks (64 -> 256 channels).
    if stride != 1 or self.in_channels != out_channels:
      projection = nn.Sequential(nn.Conv2d(self.in_channels, out_channels, 1, stride=stride, bias=False),
                                 nn.BatchNorm2d(out_channels))
    else:
      projection = None

    layers = [block(self.in_channels, inner_channels, stride, projection)]
    self.in_channels = out_channels
    for _ in range(1, num_blocks):
      layers.append(block(self.in_channels, inner_channels))

    return nn.Sequential(*layers)

  def forward(self, x):
    """Return class logits of shape ``(batch, num_classes)`` for input ``x``.

    ``x`` is expected to be ``(batch, 3, H, W)`` because ``conv1`` takes three
    input channels.
    """
    x = self.conv1(x)
    x = self.bn1(x)
    x = self.relu(x)
    x = self.maxpool(x)

    x = self.stage1(x)
    x = self.stage2(x)
    x = self.stage3(x)
    x = self.stage4(x)

    x = self.avgpool(x)
    x = torch.flatten(x, 1)
    out = self.fc(x)
    return out

In [None]:
def resnet18(**kwargs):
    """Build ResNet-18: BasicBlock stages with 2, 2, 2, 2 blocks.

    Keyword arguments (e.g. ``num_classes``, ``zero_init_residual``) are
    forwarded to ``ResNet``.
    """
    stage_depths = [2, 2, 2, 2]
    return ResNet(BasicBlock, stage_depths, **kwargs)

def resnet34(**kwargs):
    """Build ResNet-34: BasicBlock stages with 3, 4, 6, 3 blocks.

    Keyword arguments (e.g. ``num_classes``, ``zero_init_residual``) are
    forwarded to ``ResNet``.
    """
    stage_depths = [3, 4, 6, 3]
    return ResNet(BasicBlock, stage_depths, **kwargs)

def resnet50(**kwargs):
    """Build ResNet-50: BottleneckBlock stages with 3, 4, 6, 3 blocks.

    Keyword arguments (e.g. ``num_classes``, ``zero_init_residual``) are
    forwarded to ``ResNet``.
    """
    stage_depths = [3, 4, 6, 3]
    return ResNet(BottleneckBlock, stage_depths, **kwargs)

def resnet101(**kwargs):
    """Build ResNet-101: BottleneckBlock stages with 3, 4, 23, 3 blocks.

    Keyword arguments (e.g. ``num_classes``, ``zero_init_residual``) are
    forwarded to ``ResNet``.
    """
    stage_depths = [3, 4, 23, 3]
    return ResNet(BottleneckBlock, stage_depths, **kwargs)

def resnet152(**kwargs):
    """Build ResNet-152: BottleneckBlock stages with 3, 8, 36, 3 blocks.

    Keyword arguments (e.g. ``num_classes``, ``zero_init_residual``) are
    forwarded to ``ResNet``.
    """
    # The bottleneck block class in this file is named `BottleneckBlock`.
    return ResNet(BottleneckBlock, [3, 8, 36, 3], **kwargs)

In [None]:
# Print a layer-by-layer summary of ResNet-50 for a batch of two 3x224x224 inputs.
# Needs the third-party `torchinfo` package; uncomment to install in a notebook:
# !pip install torchinfo
model = resnet50()
from torchinfo import summary
summary(
    model,
    input_size=(2, 3, 224, 224),
    device='cpu',
)

Layer (type:depth-idx)                   Output Shape              Param #
ResNet                                   [2, 1000]                 --
├─Conv2d: 1-1                            [2, 64, 112, 112]         9,408
├─BatchNorm2d: 1-2                       [2, 64, 112, 112]         128
├─ReLU: 1-3                              [2, 64, 112, 112]         --
├─MaxPool2d: 1-4                         [2, 64, 56, 56]           --
├─Sequential: 1-5                        [2, 256, 56, 56]          --
│    └─BottleneckBlock: 2-1              [2, 256, 56, 56]          --
│    │    └─Sequential: 3-1              [2, 256, 56, 56]          58,112
│    │    └─Sequential: 3-2              [2, 256, 56, 56]          16,896
│    │    └─ReLU: 3-3                    [2, 256, 56, 56]          --
│    └─BottleneckBlock: 2-2              [2, 256, 56, 56]          --
│    │    └─Sequential: 3-4              [2, 256, 56, 56]          70,400
│    │    └─ReLU: 3-5                    [2, 256, 56, 56]          --

In [None]:
# Same summary for ResNet-152 (relies on `summary` imported in the earlier cell).
model = resnet152()
summary(
    model,
    input_size=(2, 3, 224, 224),
    device='cpu',
)

Layer (type:depth-idx)                   Output Shape              Param #
ResNet                                   [2, 1000]                 --
├─Conv2d: 1-1                            [2, 64, 112, 112]         9,408
├─BatchNorm2d: 1-2                       [2, 64, 112, 112]         128
├─ReLU: 1-3                              [2, 64, 112, 112]         --
├─MaxPool2d: 1-4                         [2, 64, 56, 56]           --
├─Sequential: 1-5                        [2, 256, 56, 56]          --
│    └─Bottleneck: 2-1                   [2, 256, 56, 56]          --
│    │    └─Sequential: 3-1              [2, 256, 56, 56]          58,112
│    │    └─Sequential: 3-2              [2, 256, 56, 56]          16,896
│    │    └─ReLU: 3-3                    [2, 256, 56, 56]          --
│    └─Bottleneck: 2-2                   [2, 256, 56, 56]          --
│    │    └─Sequential: 3-4              [2, 256, 56, 56]          70,400
│    │    └─ReLU: 3-5                    [2, 256, 56, 56]          --