In [66]:
import torch
import torch.nn as nn

In [68]:
%run /content/drive/MyDrive/Colab\ Notebooks/YOLOv1/utils.ipynb

In [85]:
class YoloLoss(nn.Module):
  def __init__(self, S=7, B=2, C=20):
    super(YoloLoss, self).__init__()
    self.mse = nn.MSELoss(reduction="sum")
    self.S = S
    self.B = B
    self.C = C
    self.lambda_noobj = 0.5
    self.lambda_coord = 5
  
  def forward(self, predictions, targets):
    predictions = predictions.reshape(-1, self.S, self.S, self.C + self.B * 5)
    iou_b1 = intersection_over_union(predictions[..., 21:25], targets[..., 21:25])
    iou_b2 = intersection_over_union(predictions[..., 26:30], targets[..., 21:25])
    ious = torch.cat([iou_b1.unsqueeze(0), iou_b2.unsqueeze(0)], dim=0)

    _, bestbox = torch.max(ious, dim=0) # bestbox.shape : (Batch, C, C, 1)
    exists_box = targets[..., 20].unsqueeze(3) # exists_box.shape : (Batch, C, C, 1)
    
    # ===================
    # FOR BOX COORDINATES
    # ===================
    box_predictions = exists_box * (  # about i
        bestbox * predictions[..., 26:30]
        + (1 - bestbox) * predictions[..., 21:25]
    )

    box_targets = exists_box * targets[..., 21:25]

    box_predictions[..., 2:4] = torch.sign(box_predictions[..., 2:4]) * torch.sqrt(
        torch.abs(box_predictions[..., 2:4] + 1e-6)
    )
    box_targets[..., 2:4] = torch.sqrt(box_targets[..., 2:4])

    box_loss = self.mse(
        torch.flatten(box_predictions),
        torch.flatten(box_targets)
    )

    # ===============
    # FOR OBJECT LOSS
    # ===============

    pred_box = exists_box * (
        bestbox * predictions[..., 25:26]
        + (1 - bestbox) * predictions[..., 20:21]
    )

    object_loss = self.mse(
        torch.flatten(pred_box),
        torch.flatten(exists_box * targets[..., 20:21])
    )

    # ==================
    # FOR NO OBJECT LOSS
    # ==================

    no_object_loss = self.mse(
        torch.flatten((1 - exists_box) * predictions[..., 20:21]),
        torch.flatten((1 - exists_box) * targets[..., 20:21])
    )
    no_object_loss += self.mse(
        torch.flatten((1 - exists_box) * predictions[..., 25:26]),
        torch.flatten((1 - exists_box) * targets[..., 20:21])
    )

    # ==============
    # FOR CLASS LOSS
    # ==============
    class_loss = self.mse(
        torch.flatten(exists_box * predictions[..., :20]),
        torch.flatten(exists_box * targets[..., :20])
    )

    loss = (
        self.lambda_coord * box_loss
        + object_loss
        + self.lambda_noobj * no_object_loss
        + class_loss
    )
    return loss

In [86]:
preds = torch.randn((16, 7, 7, 30))
targets = torch.abs(torch.randn((16, 7, 7, 30)))

In [87]:
yolo_loss = YoloLoss()

In [88]:
yolo_loss(preds, targets)

tensor(61803.4219)