<a href="https://colab.research.google.com/github/andricmitrovic/YOLO-object-detection/blob/main/models/model_resnet50%2BCNN_BLOCK_8%2BFC.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [4]:
import torch
import torch.nn as nn
from torchvision import models

In [5]:
# Use the first CUDA GPU when one is visible to PyTorch, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [6]:
class Yolov1(nn.Module):
    """YOLOv1-style detection model built on a pretrained torchvision ResNet-50.

    The ResNet-50 classifier (``fc``) is replaced by a small MLP head whose
    final layer has ``S * S * (C + B * 5)`` outputs. The prediction is returned
    flat, shape ``(N, S * S * (C + B * 5))``; it is not reshaped into a grid here.

    Args:
        S: Grid size (presumably the YOLOv1 ``S x S`` cell grid).
        B: Boxes per cell (each contributes 5 values to the output).
        C: Number of classes.
    """

    # How many leading direct children of the backbone get frozen. For
    # torchvision's ResNet-50 these are conv1, bn1, relu, maxpool, layer1,
    # layer2 and layer3, leaving layer4, avgpool and the new head trainable.
    _NUM_FROZEN_CHILDREN = 7

    def __init__(self, S, B, C):
        super().__init__()
        self.model = self._create_model(S, B, C)

    def forward(self, x):
        """Run backbone + head on ``x`` (NCHW images) and return the flat prediction."""
        return self.model(x)

    def _create_model(self, S, B, C):
        """Build the pretrained backbone and swap its ``fc`` layer for the YOLO head."""
        model = self._create_darknet()
        num_ftrs = model.fc.in_features
        model.fc = self._create_fcs(num_ftrs, S, B, C)
        return model

    def _create_darknet(self):
        """Return an ImageNet-pretrained ResNet-50 with its leading children frozen.

        Named after the original YOLOv1 "darknet" backbone, but the network
        returned here is torchvision's ResNet-50.
        """
        backbone = self._load_pretrained_resnet50()
        self._freeze_leading_children(backbone, self._NUM_FROZEN_CHILDREN)
        return backbone

    @staticmethod
    def _load_pretrained_resnet50():
        """Load ResNet-50 with ImageNet weights, avoiding the deprecated ``pretrained`` flag."""
        weights_enum = getattr(models, "ResNet50_Weights", None)
        if weights_enum is None:
            # torchvision < 0.13 has no weight enums; only the legacy flag exists.
            return models.resnet50(pretrained=True)
        # IMAGENET1K_V1 is exactly what ``pretrained=True`` maps to on
        # torchvision >= 0.13, so the loaded weights are unchanged.
        return models.resnet50(weights=weights_enum.IMAGENET1K_V1)

    @staticmethod
    def _freeze_leading_children(module, count):
        """Disable gradients for every parameter of the first ``count`` direct children.

        NOTE: this only stops gradient updates. BatchNorm layers inside the
        frozen part still update their running statistics while the model is
        in ``train()`` mode.
        """
        for child in list(module.children())[:count]:
            for param in child.parameters():
                param.requires_grad = False

    def _create_fcs(self, num_ftrs, S, B, C):
        """Build the detection head: ``num_ftrs -> 512 -> S * S * (C + B * 5)``."""
        return nn.Sequential(
            nn.Linear(num_ftrs, 512),
            nn.Dropout(0.5),
            nn.LeakyReLU(0.1),
            nn.Linear(512, S * S * (C + B * 5)),
        )

    def __str__(self):
        # NOTE(review): "tunning" is misspelled but kept byte-identical; this
        # string may be used elsewhere (e.g. as a run/checkpoint name).
        return "resnet50_fine_tunning"

In [7]:
def test(S=7, B=2, C=3):
    """Smoke-test ``Yolov1`` by printing the output shape for a random batch.

    Builds the model (which loads pretrained ResNet-50 weights, possibly
    downloading them) and feeds a batch of two random 3x224x224 images. It
    prints the output shape, expected to be ``(2, S * S * (C + B * 5))``.

    Returns:
        The output shape (``torch.Size``), so the check can also be asserted.
    """
    model = Yolov1(S, B, C)
    model.eval()  # deterministic forward pass: Dropout off, BatchNorm uses running stats
    # torchvision ResNets take NCHW input: (batch, 3 channels, H, W).
    x = torch.randn((2, 3, 224, 224))
    with torch.no_grad():  # shape check only; no autograd graph needed
        shape = model(x).shape
    print(shape)
    return shape

In [8]:
# test()