<a href="https://colab.research.google.com/github/andricmitrovic/YOLO-object-detection/blob/main/model_resnet18%2BFC.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [42]:
import torch
import torch.nn as nn
from torchvision import models

In [43]:
# Select the first CUDA GPU when PyTorch can see one; otherwise use the CPU.
# NOTE(review): `device` is not referenced by the cells visible here.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [44]:
class Yolov1(nn.Module):
    """YOLOv1-style detector: a ResNet-18 backbone with a fully connected head.

    The backbone's final ``fc`` layer is replaced by a two-layer MLP whose
    output has ``S * S * (C + B * 5)`` features per image, where ``S`` is
    ``split_size``, ``B`` is ``num_boxes`` and ``C`` is ``num_classes``.
    The prediction is returned flat, shape ``(N, S * S * (C + B * 5))``;
    reshaping it into a grid is left to the caller.

    Args:
        pretrained: If True (the default, matching the original behaviour),
            load torchvision's pretrained ResNet-18 weights, which may be
            downloaded on first use. Pass False for a randomly initialised
            backbone (e.g. offline tests).
        **kwargs: ``split_size``, ``num_boxes`` and ``num_classes``, forwarded
            to the head builder; all three are required.
    """

    def __init__(self, *, pretrained=True, **kwargs):
        super().__init__()
        self.net = self._create_net(pretrained=pretrained, **kwargs)

    def forward(self, x):
        """Run the backbone plus YOLO head on a batch of images ``x``.

        NOTE(review): presumably ``x`` is (N, 3, H, W) as torchvision's ResNet
        expects — confirm against the data pipeline.
        """
        return self.net(x)

    def _create_net(self, pretrained=True, **kwargs):
        """Build ResNet-18 and replace its classifier with the YOLO head."""
        # torchvision >= 0.13 deprecates ``pretrained=`` in favour of the
        # ``weights=`` enum API; fall back to the old keyword on older versions.
        weights_enum = getattr(models, "ResNet18_Weights", None)
        if weights_enum is not None:
            weights = weights_enum.DEFAULT if pretrained else None
            net = models.resnet18(weights=weights)
        else:
            net = models.resnet18(pretrained=pretrained)
        net.fc = self._create_fcs(net.fc.in_features, **kwargs)
        return net

    def _create_fcs(self, num_ftrs, split_size, num_boxes, num_classes):
        """Return the MLP head mapping backbone features to YOLO predictions.

        Args:
            num_ftrs: Number of input features (the backbone's ``fc.in_features``).
            split_size: Grid size ``S``.
            num_boxes: Boxes per cell ``B``; each box contributes 5 outputs.
            num_classes: Number of classes ``C``.

        Returns:
            ``nn.Sequential`` producing ``S * S * (C + B * 5)`` features.
        """
        S, B, C = split_size, num_boxes, num_classes
        return nn.Sequential(
            nn.Linear(num_ftrs, 4096),
            nn.LeakyReLU(0.1),
            nn.Linear(4096, S * S * (C + B * 5)),
        )

In [45]:
def test(S=7, B=2, C=3):
    """Smoke-test Yolov1 by building it and checking its flat output shape.

    Builds the model with ``split_size=S``, ``num_boxes=B``,
    ``num_classes=C``, runs a random batch of two 3x224x224 images through
    it, prints the output shape, and asserts that it equals
    ``(2, S * S * (C + B * 5))``.
    """
    model = Yolov1(split_size=S, num_boxes=B, num_classes=C)
    batch_size = 2
    # torchvision's ResNet takes a batch of images shaped (N, 3, H, W).
    x = torch.randn((batch_size, 3, 224, 224))
    with torch.no_grad():  # shape check only; no autograd graph needed
        out = model(x)
    print(out.shape)
    expected = (batch_size, S * S * (C + B * 5))
    assert tuple(out.shape) == expected, f"expected {expected}, got {tuple(out.shape)}"

In [46]:
# Run the smoke test with its default grid/box/class configuration.
test(S=7, B=2, C=3)

torch.Size([2, 637])
