<a href="https://colab.research.google.com/github/adnan119/Pytorch-Projects/blob/main/Object_Detection/YOLO/YOLO_v1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [45]:
import torch
import torch.nn as nn

In [46]:
# YOLOv1 convolutional backbone: 24 conv layers and 4 max-pools, following the
# network diagram (Figure 3) of Redmon et al., "You Only Look Once" (2016).
# Entry formats:
#   (kernel_size, num_filters, stride, padding) -> one CNNBlock
#   "M"                                         -> 2x2 max-pool, stride 2
#   [conv_a, conv_b, num_repeats]               -> conv_a then conv_b, repeated num_repeats times
# Every conv uses "same" padding ((kernel_size - 1) // 2), so only the two
# stride-2 convs and the four max-pools shrink the feature map (/64 overall,
# i.e. 448x448 -> 7x7).
architecture_config = [
    (7, 64, 2, 3),
    "M",
    (3, 192, 1, 1),
    "M",
    (1, 128, 1, 0),
    (3, 256, 1, 1),
    (1, 256, 1, 0),
    (3, 512, 1, 1),
    "M",
    [(1, 256, 1, 0), (3, 512, 1, 1), 4],
    (1, 512, 1, 0),
    (3, 1024, 1, 1),
    "M",
    [(1, 512, 1, 0), (3, 1024, 1, 1), 2],
    (3, 1024, 1, 1),
    (3, 1024, 2, 1),
    (3, 1024, 1, 1),
    (3, 1024, 1, 1),
]

In [47]:
class CNNBlock(nn.Module):
  """Conv2d -> BatchNorm2d -> LeakyReLU(0.1) building block.

  Any extra keyword arguments (kernel_size, stride, padding, ...) are passed
  straight to nn.Conv2d. The convolution is created without a bias because the
  BatchNorm2d that follows it makes a separate conv bias redundant.
  """

  def __init__(self, in_channels, out_channels, **kwargs):
    super().__init__()
    # Registration order (conv, batchnorm, leakyrelu) determines state_dict key order.
    self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
    self.batchnorm = nn.BatchNorm2d(out_channels)
    self.leakyrelu = nn.LeakyReLU(0.1)

  def forward(self, x):
    features = self.conv(x)
    normalized = self.batchnorm(features)
    return self.leakyrelu(normalized)

In [48]:
class yolov1(nn.Module):
  """YOLOv1 detector: convolutional backbone built from a layer spec plus a
  small fully connected head.

  Args:
    in_channels: number of input image channels (3 for RGB).
    **kwargs: forwarded to ``_create_fcs``; must supply ``split_size`` (S),
      ``num_boxes`` (B) and ``num_classes`` (C).

  ``forward`` returns a tensor of shape (N, S * S * (C + B * 5)).

  NOTE(review): the head's first Linear expects 1024 * S * S input features,
  i.e. it assumes the backbone's output feature map is S x S. With the
  module-level ``architecture_config`` the backbone downsamples by 64 (two
  stride-2 convs and four max-pools), so the input side length must be
  64 * S (448 for S = 7); other combinations fail with a shape mismatch.
  """

  def __init__(self, in_channels=3, **kwargs):
    super(yolov1, self).__init__()
    self.architecture = architecture_config
    self.in_channels = in_channels
    self.darknet = self._create_conv_layers(self.architecture)
    self.fcs = self._create_fcs(**kwargs)

  def forward(self, x):
    x = self.darknet(x)
    return self.fcs(torch.flatten(x, start_dim=1))

  @staticmethod
  def _conv_block(in_channels, spec):
    """Build one CNNBlock from a (kernel_size, num_filters, stride, padding) tuple."""
    kernel_size, out_channels, stride, padding = spec
    return CNNBlock(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
    )

  def _create_conv_layers(self, architecture):
    """Translate a layer spec into an nn.Sequential backbone.

    Supported entries:
      tuple (kernel_size, num_filters, stride, padding) -> one CNNBlock
      str (e.g. "M")                                  -> MaxPool2d(2, stride=2)
      list [conv_spec, ..., num_repeats]              -> the conv specs in order,
                                                         repeated num_repeats times

    Raises:
      ValueError: for an entry of any other type (previously such entries
        were silently skipped, producing a truncated network).
    """
    layers = []
    in_channels = self.in_channels

    for entry in architecture:
      if isinstance(entry, tuple):
        layers.append(self._conv_block(in_channels, entry))
        in_channels = entry[1]
      elif isinstance(entry, str):
        layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
      elif isinstance(entry, list):
        *conv_specs, num_repeats = entry
        for _ in range(num_repeats):
          for spec in conv_specs:
            layers.append(self._conv_block(in_channels, spec))
            in_channels = spec[1]
      else:
        raise ValueError(f"Unsupported architecture entry: {entry!r}")
    return nn.Sequential(*layers)

  def _create_fcs(self, split_size, num_boxes, num_classes):
    """Build the fully connected head mapping backbone features to
    S * S * (C + B * 5) outputs."""
    S, B, C = split_size, num_boxes, num_classes
    # The layer list is kept exactly as-is so state_dict keys (fcs.1.*, fcs.4.*)
    # stay stable: nn.Flatten is redundant with torch.flatten in forward and
    # Dropout(0.0) is a no-op, but removing either would shift the indices.
    return nn.Sequential(
        nn.Flatten(),
        nn.Linear(1024 * S * S, 512),
        nn.Dropout(0.0),
        nn.LeakyReLU(0.1),
        nn.Linear(512, S * S * (C + B * 5)),
    )

In [49]:
def test(S = 7, B = 2, C = 20):
  """Smoke-test yolov1 on a random batch of two 448x448 RGB images.

  Prints the output shape (as before) and raises AssertionError if it is not
  (2, S * S * (C + B * 5)). Returns None.

  NOTE: the input size is fixed at 448x448. Because the head's first Linear is
  sized 1024 * S * S, values of S other than 7 presumably fail with this input.
  """
  model = yolov1(split_size=S, num_boxes=B, num_classes=C)
  batch_size = 2
  x = torch.randn((batch_size, 3, 448, 448))
  out = model(x)
  print(out.shape)
  expected = (batch_size, S * S * (C + B * 5))
  assert out.shape == expected, (
      f"expected output shape {expected}, got {tuple(out.shape)}"
  )

In [50]:
# Run the smoke test with the standard YOLOv1 settings (7x7 grid, 2 boxes, 20 classes).
test(S=7, B=2, C=20)

torch.Size([2, 1470])


In [51]:
class yololoss(nn.Module):
  """YOLOv1 sum-squared-error loss for B = 2 predicted boxes per grid cell.

  Per-cell channel layout expected in ``prediction`` (after reshaping to
  (N, S, S, C + B * 5)) and in ``target``:
    [0:C]        class scores
    [C]          confidence of box 1 (in ``target``: object indicator Iobj_i,
                 used as a mask -- presumably 0/1)
    [C+1:C+5]    box 1 as (x_mid, y_mid, width, height)
    [C+5]        confidence of box 2
    [C+6:C+10]   box 2 as (x_mid, y_mid, width, height)
  ``target`` only needs its first C + 5 channels.

  Args:
    S: grid size.
    B: boxes per cell; only B == 2 is supported.
    C: number of classes.
    iou_fn: optional callable ``iou_fn(pred_boxes, target_boxes)`` returning an
      IoU tensor that broadcasts against (N, S, S, 4) (presumably
      (N, S, S, 1)). Defaults to the module-level ``intersection_over_union``.
      NOTE(review): that function is not defined in this notebook section --
      confirm it is imported before relying on the default.

  Raises:
    ValueError: if B != 2 (the box selection below handles exactly two boxes).
  """

  def __init__(self, S = 7, B = 2, C = 20, iou_fn=None):
    super(yololoss, self).__init__()
    if B != 2:
      raise ValueError(f"yololoss only supports B=2 boxes per cell, got B={B}")
    # Sum (not mean) of squared errors, as in the YOLOv1 paper.
    self.mse = nn.MSELoss(reduction="sum")
    self.S = S
    self.B = B
    self.C = C
    self.iou_fn = iou_fn
    self.lambda_coord = 5
    self.lambda_noobj = 0.5

  def forward(self, prediction, target):
    """Compute the total YOLOv1 loss.

    Args:
      prediction: raw model output, reshapeable to (N, S, S, C + B * 5),
        e.g. the (N, S * S * (C + B * 5)) output of the yolov1 model.
      target: (N, S, S, >= C + 5) tensor in the layout described on the class.

    Returns:
      0-dim tensor: lambda_coord * box_loss + object_loss
      + lambda_noobj * no_object_loss + class_loss.
    """
    C = self.C
    prediction = prediction.reshape(-1, self.S, self.S, C + self.B * 5)
    iou = self.iou_fn if self.iou_fn is not None else intersection_over_union

    pred_conf1 = prediction[..., C:C + 1]
    pred_box1 = prediction[..., C + 1:C + 5]
    pred_conf2 = prediction[..., C + 5:C + 6]
    pred_box2 = prediction[..., C + 6:C + 10]
    target_conf = target[..., C:C + 1]
    target_box = target[..., C + 1:C + 5]

    # The predictor whose box has the higher IoU with the target box is the
    # "responsible" one: best_box is 0 for box 1 and 1 for box 2.
    iou_b1 = iou(pred_box1, target_box)
    iou_b2 = iou(pred_box2, target_box)
    ious = torch.cat([iou_b1.unsqueeze(0), iou_b2.unsqueeze(0)], dim=0)
    _, best_box = torch.max(ious, dim=0)
    exists_box = target_conf  # Iobj_i mask, (N, S, S, 1)

    ####  BOX LOSS  ####
    # (x, y) are compared directly; (w, h) through a square root, as in the
    # YOLOv1 paper. sign/abs keep the sqrt argument non-negative for negative
    # raw predictions, and the 1e-6 keeps it away from 0 where the sqrt
    # gradient is infinite.
    box_predictions = exists_box * (
        best_box * pred_box2 + (1 - best_box) * pred_box1
    )
    box_targets = exists_box * target_box
    pred_wh = torch.sign(box_predictions[..., 2:4]) * torch.sqrt(
        torch.abs(box_predictions[..., 2:4] + 1e-6)
    )
    # Rebuild with torch.cat instead of writing into slices in place, so no
    # autograd-tracked tensor is mutated.
    box_predictions = torch.cat([box_predictions[..., :2], pred_wh], dim=-1)
    box_targets = torch.cat(
        [box_targets[..., :2], torch.sqrt(box_targets[..., 2:4])], dim=-1
    )
    # box tensors: (N, S, S, 4) -> (N*S*S, 4)
    box_loss = self.mse(
        torch.flatten(box_predictions, end_dim=-2),
        torch.flatten(box_targets, end_dim=-2),
    )

    #### OBJECT LOSS ####
    # Confidence of the responsible box in cells that contain an object.
    pred_conf = best_box * pred_conf2 + (1 - best_box) * pred_conf1
    # (N, S, S, 1) -> (N*S*S*1)
    object_loss = self.mse(
        torch.flatten(exists_box * pred_conf),
        torch.flatten(exists_box * target_conf),
    )

    #### NO-OBJECT LOSS ####
    # Both boxes' confidences are penalised in cells without an object; the
    # two terms are summed (previously the box-2 term overwrote box 1's).
    # (N, S, S, 1) -> (N, S*S*1)
    no_object_loss = self.mse(
        torch.flatten((1 - exists_box) * pred_conf1, start_dim=1),
        torch.flatten((1 - exists_box) * target_conf, start_dim=1),
    )
    no_object_loss = no_object_loss + self.mse(
        torch.flatten((1 - exists_box) * pred_conf2, start_dim=1),
        torch.flatten((1 - exists_box) * target_conf, start_dim=1),
    )

    #### CLASS LOSS ####
    # (N, S, S, C) -> (N*S*S, C)
    class_loss = self.mse(
        torch.flatten(exists_box * prediction[..., :C], end_dim=-2),
        torch.flatten(exists_box * target[..., :C], end_dim=-2),
    )

    loss = (
        self.lambda_coord * box_loss
        + object_loss
        + self.lambda_noobj * no_object_loss
        + class_loss
    )

    return loss