In [1]:
import torch
import torch.nn as nn

In [76]:
# Backbone description consumed by Yolov1._create_conv_layers.
# Entry formats:
#   (kernel_size, output_channels, stride, padding) -> one CNNBlock
#   "M"                                             -> 2x2 max-pool with stride 2
#   [conv_spec, conv_spec, num_repeats]             -> the conv specs, repeated num_repeats times
# For a 448x448 input, the stride-2 first conv, the four max-pools and the
# stride-2 conv near the end reduce the feature map to 7x7 with 1024 channels.
architecture_config = [
    # 448 -> 224 (conv) -> 112 (pool)
    (7, 64, 2, 3), "M",
    # 112 -> 56
    (3, 192, 1, 1), "M",
    # 56 -> 28
    (1, 128, 1, 0), (3, 256, 1, 1), (1, 256, 1, 0), (3, 512, 1, 1), "M",
    # 28 -> 14
    [(1, 256, 1, 0), (3, 512, 1, 1), 4],
    (1, 512, 1, 0), (3, 1024, 1, 1), "M",
    # 14 -> 7 (the second conv below has stride 2)
    [(1, 512, 1, 0), (3, 1024, 1, 1), 2],
    (3, 1024, 1, 1), (3, 1024, 2, 1),
    # 7 -> 7
    (3, 1024, 1, 1), (3, 1024, 1, 1),
]


In [69]:
class CNNBlock(nn.Module):
  """Conv2d -> BatchNorm2d -> LeakyReLU(negative_slope=0.1).

  Args:
    in_channels: channels of the incoming feature map.
    out_channels: channels produced by the convolution (and normalized by
      the BatchNorm).
    **kwargs: forwarded verbatim to ``nn.Conv2d`` (e.g. ``kernel_size``,
      ``stride``, ``padding``). ``bias`` must not be passed; it is fixed to
      False.
  """

  def __init__(self, in_channels, out_channels, **kwargs):
    super().__init__()
    # The conv bias is disabled: BatchNorm2d (affine by default) supplies a
    # learnable shift right after it. Attribute names are part of the
    # state_dict keys, so they are kept stable.
    self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
    self.batchnorm = nn.BatchNorm2d(out_channels)
    self.leakyrelu = nn.LeakyReLU(0.1)

  def forward(self, x):
    """Apply convolution, normalization and activation in sequence."""
    features = self.conv(x)
    normalized = self.batchnorm(features)
    return self.leakyrelu(normalized)

In [85]:
class Yolov1(nn.Module):
  """YOLOv1-style detector: a convolutional backbone plus a fully connected head.

  Args:
    in_channels: channels of the input images (3 for RGB).
    architecture: backbone description. Defaults to the module-level
      ``architecture_config``. Supported entries:
        * ``(kernel_size, out_channels, stride, padding)`` -> one CNNBlock
        * any string (``"M"``) -> 2x2 max-pool with stride 2
        * ``[conv_spec, ..., num_repeats]`` -> the conv specs applied in
          order, the whole group repeated ``num_repeats`` times
    **kwargs: forwarded to ``_create_fcs``: ``split_size`` (S),
      ``num_boxes`` (B) and ``num_classes`` (C).

  The head's input size is ``backbone_out_channels * S * S``, so the backbone
  must reduce the input images to an S x S feature map (e.g. 448x448 -> 7x7
  with the default architecture).

  Raises:
    ValueError: if ``architecture`` contains an unsupported entry.
  """

  def __init__(self, in_channels=3, architecture=None, **kwargs):
    super(Yolov1, self).__init__()
    self.architecture = architecture_config if architecture is None else architecture
    self.in_channels = in_channels
    # Must run before _create_fcs: it records the backbone's output channels.
    self.darknet = self._create_conv_layers(self.architecture)
    self.fcs = self._create_fcs(**kwargs)

  def forward(self, x):
    """Map a batch of images ``(N, in_channels, H, W)`` to ``(N, S * S * (C + B * 5))``."""
    # self.fcs starts with nn.Flatten (start_dim=1), so no explicit flatten is needed here.
    return self.fcs(self.darknet(x))

  @staticmethod
  def _conv_block(in_channels, spec):
    """Build one CNNBlock from a ``(kernel_size, out_channels, stride, padding)`` spec."""
    kernel_size, out_channels, stride, padding = spec
    return CNNBlock(
        in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding,
    )

  def _create_conv_layers(self, architecture):
    """Translate ``architecture`` into an ``nn.Sequential`` backbone.

    Also stores the channel count of the last conv block (or ``in_channels``
    if there is none) in ``self._darknet_out_channels`` for sizing the head.
    """
    layers = []
    in_channels = self.in_channels

    for entry in architecture:
      if isinstance(entry, tuple):
        layers.append(self._conv_block(in_channels, entry))
        in_channels = entry[1]
      elif isinstance(entry, str):
        layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
      elif isinstance(entry, list):
        *conv_specs, num_repeats = entry
        for _ in range(num_repeats):
          for spec in conv_specs:
            layers.append(self._conv_block(in_channels, spec))
            in_channels = spec[1]
      else:
        # Previously such entries were silently skipped, producing a wrong network.
        raise ValueError(f"Unsupported architecture entry: {entry!r}")

    self._darknet_out_channels = in_channels
    return nn.Sequential(*layers)

  def _create_fcs(self, split_size=7, num_boxes=2, num_classes=20):
    """Build the fully connected head.

    Args:
      split_size: grid size S.
      num_boxes: boxes predicted per grid cell, B.
      num_classes: number of classes, C.

    Returns:
      ``nn.Sequential`` producing ``S * S * (C + B * 5)`` values per image.
      The layer order is unchanged, so the state_dict keys (``fcs.1.*``,
      ``fcs.4.*``) stay compatible.
    """
    S, B, C = split_size, num_boxes, num_classes
    in_features = self._darknet_out_channels * S * S
    return nn.Sequential(
        nn.Flatten(),
        nn.Linear(in_features, 496),  # hidden width reduced from 4096 to 496
        nn.Dropout(0.5),
        nn.LeakyReLU(0.1),
        nn.Linear(496, S * S * (C + B * 5)),  # C + B * 5 values per cell (30 for C=20, B=2)
    )

def test(S=7, B=2, C=20, img_size=None):
  """Smoke-test Yolov1 on a random batch of 4 images and return its output.

  Args:
    S: grid size (split_size).
    B: boxes per grid cell (num_boxes).
    C: number of classes (num_classes).
    img_size: side length of the random square input. Defaults to ``64 * S``.
      The default ``architecture_config`` downsamples by 64 (stride-2 conv,
      four 2x2 max-pools, stride-2 conv), so this gives the S x S feature map
      the head expects (448 for S=7).

  Returns:
    Tensor of shape ``(4, S * S * (C + B * 5))``.

  Raises:
    AssertionError: if the model output does not have the expected shape.
  """
  if img_size is None:
    img_size = 64 * S
  model = Yolov1(split_size=S, num_boxes=B, num_classes=C)
  x = torch.randn((4, 3, img_size, img_size))
  with torch.no_grad():  # inference-only smoke test; no autograd graph needed
    ret = model(x)
  print(ret.shape)
  expected = (x.shape[0], S * S * (C + B * 5))
  assert tuple(ret.shape) == expected, f"expected {expected}, got {tuple(ret.shape)}"
  return ret

In [87]:
# Grid size, boxes per cell and number of classes for the smoke test. The
# reshape views each image's flat prediction as an (S, S, C + B * 5) grid,
# so it stays consistent if these values are changed.
SPLIT_SIZE, NUM_BOXES, NUM_CLASSES = 7, 2, 20
ret = test(S=SPLIT_SIZE, B=NUM_BOXES, C=NUM_CLASSES)
ret.reshape((-1, SPLIT_SIZE, SPLIT_SIZE, NUM_CLASSES + NUM_BOXES * 5))

torch.Size([4, 1470])


tensor([[[[ 3.0007e-01,  4.5405e-01,  1.5534e-01,  ..., -1.1476e-01,
           -2.3390e-01, -3.1925e-02],
          [-5.1170e-01,  2.4469e-01, -1.4770e-02,  ..., -1.4767e-01,
           -1.1443e-01,  1.3856e-01],
          [-1.2273e-01,  1.2067e-01,  1.6377e-01,  ...,  2.8276e-01,
            1.2618e-01,  6.0360e-01],
          ...,
          [-3.0519e-01, -2.7465e-01, -1.7495e-01,  ...,  1.3903e-01,
           -2.7004e-01, -7.1632e-02],
          [ 4.8382e-02,  2.8441e-02, -1.4999e-01,  ...,  2.4577e-01,
           -3.3099e-01,  1.7906e-01],
          [ 2.6578e-03,  3.5526e-01, -1.3390e-01,  ...,  3.9144e-01,
            3.1011e-01, -2.7495e-01]],

         [[ 1.8188e-01,  3.8515e-02, -3.5500e-01,  ..., -1.7852e-01,
           -2.1266e-01, -2.7858e-01],
          [ 1.0862e-01,  4.9627e-01,  1.5595e-01,  ...,  1.5274e-01,
           -2.5286e-01,  1.7241e-01],
          [ 2.4853e-01, -2.7765e-01,  3.3553e-01,  ..., -4.0882e-01,
           -4.9007e-01, -3.8249e-01],
          ...,
     