In [2]:
import torch
import torch.nn as nn

In [16]:
class PreActResidualUnit(torch.nn.Module):
  """Bottleneck residual unit with BatchNorm -> ReLU -> Conv ordering.

  The residual branch stacks three BatchNorm/ReLU/Conv stages:
    1. 1x1 conv, in_channels -> mid_channels, carrying `stride`;
    2. 3x3 conv, mid_channels -> mid_channels, padding 1;
    3. 1x1 conv, mid_channels -> mid_channels * expansion.

  The skip branch is the identity when the input already has the shape of the
  residual output (stride == 1 and in_channels == mid_channels * expansion).
  Otherwise it is a strided 1x1 conv followed by BatchNorm that maps the
  input to that shape.

  Args:
    in_channels: channel count of the input tensor.
    mid_channels: channel width inside the bottleneck.
    expansion: multiplier giving the output channels (mid_channels * expansion).
    stride: stride of the first 1x1 conv and of the projection skip.
  """

  def __init__(self, in_channels, mid_channels, expansion, stride=1):
    super().__init__()
    self.stride = stride
    self.in_channels = in_channels
    self.mid_channels = mid_channels
    self.expansion = expansion

    out_channels = mid_channels * expansion

    self.bottleneck = nn.Sequential(
        nn.BatchNorm2d(in_channels),
        nn.ReLU(inplace=True),
        nn.Conv2d(in_channels, mid_channels, kernel_size=1, stride=stride),

        nn.BatchNorm2d(mid_channels),
        nn.ReLU(inplace=True),
        nn.Conv2d(mid_channels, mid_channels, kernel_size=3, padding=1, stride=1),

        nn.BatchNorm2d(mid_channels),
        nn.ReLU(inplace=True),
        nn.Conv2d(mid_channels, out_channels, kernel_size=1, stride=1)
    )

    # Build the projection only when the skip has to change shape. Creating it
    # unconditionally registers a conv + BatchNorm that forward() never calls
    # for identity units: dead weights in the optimizer and checkpoints, and
    # "unused parameter" trouble under DistributedDataParallel.
    if stride != 1 or in_channels != out_channels:
      self.projection_skip = nn.Sequential(
          nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride),
          nn.BatchNorm2d(out_channels)
      )
    else:
      self.projection_skip = nn.Identity()


  def forward(self, x):
    """Return the residual branch output added to the (possibly projected) skip."""
    # The first op of the bottleneck is BatchNorm, which allocates a new
    # tensor, so the in-place ReLUs never overwrite `x` used by the skip.
    return self.bottleneck(x) + self.projection_skip(x)



class ResidualBlock(torch.nn.Module):
  """A stage of `n_units` chained PreActResidualUnit modules (expansion 4).

  The first unit maps `in_channels` to `mid_channels * 4` and carries
  `stride`; every following unit takes `mid_channels * 4` channels in and
  uses stride 1.

  Args:
    n_units: total number of units in the stage.
    in_channels: channel count fed to the first unit.
    mid_channels: bottleneck width shared by all units.
    stride: stride given to the first unit only.
  """

  def __init__(self, n_units, in_channels, mid_channels, stride=1):
    super().__init__()

    expansion = 4
    widened = mid_channels * expansion

    # The first unit is created before the rest so module construction order
    # (and therefore parameter initialisation order) is unchanged.
    first_unit = PreActResidualUnit(in_channels, mid_channels, expansion, stride=stride)
    remaining_units = [
        PreActResidualUnit(widened, mid_channels, expansion, 1)
        for _ in range(n_units - 1)
    ]

    # Plain list of the units; nn.Sequential below is what registers them.
    self.layers = [first_unit, *remaining_units]
    self.block = nn.Sequential(*self.layers)

  def forward(self, x):
    """Run `x` through every unit in order."""
    out = self.block(x)
    return out



class ResNet50BackBone(torch.nn.Module):
  """ResNet-50-style feature extractor: a conv stem followed by stages C2-C5.

  The stem (7x7 stride-2 conv, BatchNorm, ReLU, 3x3 stride-2 max-pool) takes a
  3-channel input and reduces its spatial size by roughly 4x. C2 keeps that
  resolution; C3, C4 and C5 each halve it again through the stride of their
  first unit, so the returned C5 map is at roughly 1/32 of the input
  resolution.
  """

  def __init__(self):
    super().__init__()
    self.stem = nn.Sequential(
        nn.Conv2d(3, 64, kernel_size=7, padding=3, stride=2),
        nn.BatchNorm2d(64),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
    )

    # (n_units, in_channels, mid_channels) per stage follow the 3-4-6-3 layout.
    # C2 stays at stride 1 because the max-pool has already downsampled.
    self.C2 = ResidualBlock(3, 64, 64)
    # C3-C5 must downsample: with the default stride=1 every stage would keep
    # the stem's 1/4 resolution instead of the 1/8, 1/16, 1/32 pyramid.
    self.C3 = ResidualBlock(4, 256, 128, stride=2)
    self.C4 = ResidualBlock(6, 512, 256, stride=2)
    self.C5 = ResidualBlock(3, 1024, 512, stride=2)

  def forward(self, x):
    """Return the C5 feature map for input batch `x`."""
    x = self.stem(x)
    x = self.C2(x)
    x = self.C3(x)
    x = self.C4(x)
    x = self.C5(x)

    return x

In [18]:
# Build the backbone; as the cell's last expression, its module repr is shown below.
ResNet50BackBone()

ResNet50BackBone(
  (stem): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  )
  (C2): ResidualBlock(
    (block): Sequential(
      (0): PreActResidualUnit(
        (bottleneck): Sequential(
          (0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (1): ReLU(inplace=True)
          (2): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
          (3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (4): ReLU(inplace=True)
          (5): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (6): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (7): ReLU(inplace=True)
          (8): Conv2d(64, 256, kern