# A bit about how YOLO-V1 Loss works

The loss function used in YOLO-v1 consists of a weighted sum of several squared-error terms.

- First, we take the sum of the squared error of the midpoint for each box in each cell, but only the box that is responsible for predicting the object is considered (that is, the box with the highest IOU).

- Second, we take the sum of the squared error of the square roots of the width and the height. Again, only for the responsible box.

- Third, we take the sum of the squared error of the probability that there is an object. Again, only for the responsible box.

- Fourth, we take the sum of the squared error of the predicted object probability in cells where there is no object (this term is down-weighted).

- Finally, for each cell, if there is an object, we take the sum of the squared error for each class prediction vs the actual class label.

This is easier to understand by reading the code below.

In [2]:
import torch
import torch.nn as nn
import numpy as np



In [3]:
def intersection_over_union(boxes_preds, boxes_labels):
    """Return the intersection-over-union of two sets of midpoint-format boxes.

    Both inputs are indexed on their last dimension as
    (x_center, y_center, width, height) at positions 0..3. The result drops
    that last dimension, so inputs of shape (..., 4) give an output of
    shape (...).

    Args:
        boxes_preds: tensor of predicted boxes.
        boxes_labels: tensor of reference boxes.

    Returns:
        Tensor of IoU values, one per box pair.
    """

    def to_corners(boxes):
        # Midpoint format -> (left, top, right, bottom) edges.
        center_x, center_y = boxes[..., 0], boxes[..., 1]
        width, height = boxes[..., 2], boxes[..., 3]
        return (
            center_x - width / 2,
            center_y - height / 2,
            center_x + width / 2,
            center_y + height / 2,
        )

    pred_left, pred_top, pred_right, pred_bottom = to_corners(boxes_preds)
    label_left, label_top, label_right, label_bottom = to_corners(boxes_labels)

    # Edges of the overlapping rectangle (if the boxes overlap at all).
    left = torch.maximum(pred_left, label_left)
    top = torch.maximum(pred_top, label_top)
    right = torch.minimum(pred_right, label_right)
    bottom = torch.minimum(pred_bottom, label_bottom)

    # Clamp at zero so disjoint boxes contribute no overlap.
    overlap_w = (right - left).clamp(min=0)
    overlap_h = (bottom - top).clamp(min=0)
    intersection = overlap_w * overlap_h

    pred_area = (pred_right - pred_left) * (pred_bottom - pred_top)
    label_area = (label_right - label_left) * (label_bottom - label_top)
    union = pred_area + label_area - intersection

    # Add a small epsilon to avoid division by zero.
    return intersection / (union + 1e-6)

In [4]:
class Yolo_Loss(nn.Module):
    """Sum-of-squared-errors loss for a YOLO-v1 style grid prediction.

    Layout of the last dimension, as indexed by this class:

    * predictions: ``[0:C]`` class scores, then ``B`` groups of five values
      ``(confidence, x, y, w, h)`` starting at index ``C``.
    * target: ``[0:C]`` class labels, ``[C]`` objectness (1 where the cell
      contains an object), ``[C+1:C+5]`` the ground-truth box ``(x, y, w, h)``.
      Any further target entries are ignored.

    The loss is
    ``lambda_coord * box_loss + object_loss + lambda_noobj * no_object_loss
    + class_loss``.
    """

    def __init__(self, S= 7, B = 2, C = 20):
        """Store the grid size ``S``, boxes per cell ``B`` and class count ``C``."""
        super().__init__()

        # reduction="sum": every term is a plain sum of squared errors.
        self.mse = nn.MSELoss(reduction="sum")
        self.S = S
        self.B = B
        self.C = C
        self.lambda_noobj = 0.5  # weight of the confidence loss in empty cells
        self.lambda_coord = 5  # weight of the box-coordinate loss

    def forward(self, predictions, target):
        """Compute the scalar loss.

        Args:
            predictions: tensor reshaped here to ``(-1, S, S, C + 5 * B)``.
            target: tensor indexed as ``(N, S, S, >= C + 5)``; it is not
                reshaped, so it must already have that layout.

        Returns:
            Scalar tensor with the weighted sum of the four loss terms.
        """
        S, B, C = self.S, self.B, self.C
        predictions = predictions.reshape(-1, S, S, C + 5 * B)

        exists_box = target[..., C:C + 1]  # (N, S, S, 1), 1 where a cell has an object
        target_conf = target[..., C:C + 1]
        target_box = target[..., C + 1:C + 5]

        # View the box part as (N, S, S, B, 5) = (confidence, x, y, w, h) per box.
        per_box = predictions[..., C:C + 5 * B].reshape(-1, S, S, B, 5)

        # Every predicted box is compared against the SAME ground-truth box;
        # the one with the highest IoU is "responsible" for the object.
        ious = torch.stack(
            [
                intersection_over_union(per_box[..., b, 1:5], target_box)
                for b in range(B)
            ],
            dim=0,
        )
        _, best_box = torch.max(ious, dim=0)  # (N, S, S) index of responsible box

        # Pick the responsible box's five values in every cell.
        gather_index = best_box[..., None, None].expand(-1, -1, -1, 1, 5)
        responsible = per_box.gather(3, gather_index).squeeze(3)  # (N, S, S, 5)

        # ======================= #
        # Box loss (midpoint and scale)
        # ======================= #
        pred_xy = exists_box * responsible[..., 1:3]
        pred_wh = exists_box * responsible[..., 3:5]
        # Width/height are compared through their square roots. Predictions
        # may be negative, so take sqrt(|x|) and restore the sign. The epsilon
        # is added OUTSIDE abs so the sqrt argument is never 0 (the derivative
        # of sqrt is infinite there).
        pred_wh = torch.sign(pred_wh) * torch.sqrt(torch.abs(pred_wh) + 1e-6)

        target_xy = exists_box * target_box[..., 0:2]
        target_wh = torch.sqrt(exists_box * target_box[..., 2:4])

        box_loss = self.mse(
            torch.cat([pred_xy, pred_wh], dim=-1),
            torch.cat([target_xy, target_wh], dim=-1),
        )

        # ======================= #
        # Object Loss
        # ======================= #
        object_loss = self.mse(
            exists_box * responsible[..., 0:1],
            exists_box * target_conf,
        )

        # ======================= #
        # No Object Loss
        # ======================= #
        # In cells without an object, the confidence of EVERY predicted box
        # is penalised.
        no_object = 1 - exists_box
        pred_conf = per_box[..., 0]  # (N, S, S, B)
        no_object_loss = self.mse(
            no_object * pred_conf,
            (no_object * target_conf).expand_as(pred_conf),
        )

        # ======================= #
        # Class Loss
        # ======================= #
        class_loss = self.mse(
            exists_box * predictions[..., :C],
            exists_box * target[..., :C],
        )

        # ======================= #
        # YOLO Loss
        # ======================= #
        loss = (
            self.lambda_coord * box_loss
            + object_loss
            + self.lambda_noobj * no_object_loss
            + class_loss
        )

        return loss





In [None]:
# Small test case