<a href="https://colab.research.google.com/github/andricmitrovic/YOLO-object-detection/blob/main/model_resnet18%2BFC.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
from torchvision import models

In [2]:
# Prefer the first CUDA GPU when one is visible to PyTorch; otherwise run on CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [3]:
class Yolov1(nn.Module):
    """YOLOv1-style detector: a ResNet-18 backbone with a fully connected head.

    The torchvision ResNet-18 classifier layer (``fc``) is replaced by a
    two-layer head that emits a flat vector of length ``S * S * (C + B * 5)``
    per image, where ``S = split_size``, ``B = num_boxes`` and
    ``C = num_classes``.

    Args:
        pretrained: Load the default ImageNet weights for the backbone
            (default ``True``, the original behaviour). Pass ``False`` to
            build the network without downloading weights.
        freeze_backbone: Set ``requires_grad=False`` on every backbone
            parameter (default ``True``, the original behaviour). The head is
            always trainable.
        **kwargs: Forwarded to the head builder; must supply ``split_size``,
            ``num_boxes`` and ``num_classes``.
    """

    def __init__(self, pretrained=True, freeze_backbone=True, **kwargs):
        super().__init__()
        self.model = self._create_model(
            pretrained=pretrained, freeze_backbone=freeze_backbone, **kwargs
        )

    def forward(self, x):
        """Run the backbone and head on the image batch ``x``.

        ``x`` is a batched image tensor such as ``(N, 3, 224, 224)``; the
        result has shape ``(N, S * S * (C + B * 5))``.
        """
        return self.model(x)

    def _create_model(self, pretrained=True, freeze_backbone=True, **kwargs):
        """Build the backbone and swap its ``fc`` layer for the YOLO head."""
        model = self._create_darknet(pretrained=pretrained, freeze=freeze_backbone)

        num_ftrs = model.fc.in_features
        # The head is attached *after* freezing, so its freshly created
        # parameters keep requires_grad=True and remain trainable.
        model.fc = self._create_fcs(num_ftrs, **kwargs)
        return model

    def _create_darknet(self, pretrained=True, freeze=True):
        """Return a torchvision ResNet-18 used as the feature extractor.

        Despite the name, this is ResNet-18, not Darknet; the name is kept
        for compatibility with existing callers.

        Args:
            pretrained: Load the default ImageNet weights when ``True``.
            freeze: Disable gradients for all backbone parameters when ``True``.
        """
        weights_enum = getattr(models, "ResNet18_Weights", None)
        if weights_enum is not None:
            # torchvision >= 0.13 deprecates `pretrained=` in favour of
            # `weights=`; DEFAULT is the same ImageNet checkpoint.
            weights = weights_enum.DEFAULT if pretrained else None
            backbone = models.resnet18(weights=weights)
        else:
            # Older torchvision without the weights enums.
            backbone = models.resnet18(pretrained=pretrained)

        if freeze:
            for param in backbone.parameters():
                param.requires_grad = False
            # NOTE: requires_grad=False does not stop BatchNorm layers from
            # updating their running statistics while the module is in
            # train() mode.
        return backbone

    def _create_fcs(self, num_ftrs, split_size, num_boxes, num_classes):
        """Build the detection head mapping ``num_ftrs`` features to the grid.

        Args:
            num_ftrs: Number of input features (the backbone's ``fc.in_features``).
            split_size: Grid size ``S``.
            num_boxes: Boxes predicted per cell ``B``.
            num_classes: Number of classes ``C``.

        Returns:
            An ``nn.Sequential`` producing ``S * S * (C + B * 5)`` outputs.
        """
        S, B, C = split_size, num_boxes, num_classes

        return nn.Sequential(
            nn.Linear(num_ftrs, 4096),
            nn.Dropout(0.5),
            nn.LeakyReLU(0.1),
            # Per cell: C class scores + B boxes x (x, y, w, h, confidence).
            nn.Linear(4096, S * S * (C + B * 5)),
        )

In [4]:
def test(S=7, B=2, C=3):
    """Smoke-test ``Yolov1`` on a random batch and print the output shape.

    Builds the model with grid size ``S``, ``B`` boxes per cell and ``C``
    classes, then runs one forward pass on two random 224x224 three-channel
    inputs. The printed shape should be ``(2, S * S * (C + B * 5))``.

    Returns:
        The ``torch.Size`` of the model output.
    """
    model = Yolov1(split_size=S, num_boxes=B, num_classes=C)
    model.eval()  # disable dropout; only the output shape is inspected
    # ResNet expects a batched input shaped (N, 3, H, W); here N=2, H=W=224.
    x = torch.randn((2, 3, 224, 224))
    with torch.no_grad():  # no backward pass, so skip building the autograd graph
        shape = model(x).shape
    print(shape)
    return shape

In [5]:
# test()