In [10]:
from model import intersection_over_union
import torch.nn as nn
import torch
import config

In [9]:
def forward(preds, target, anchors):
    """Compute the weighted sum of the four detection-loss terms.

    The last axis of ``preds`` is read as
    ``[objectness logit, x, y, w, h, class logits...]`` and the last axis of
    ``target`` as ``[objectness flag, x, y, w, h, class index]``.

    Neither ``preds`` nor ``target`` is modified: the encoded box tensors
    needed for the coordinate loss are built out of place, so the same
    ``target`` can safely be passed again (e.g. to a metric or a second
    call) and ``preds`` may be a tensor that requires grad.

    Args:
        preds: Raw network output. Its leading axes must broadcast against
            the anchors reshaped to ``(1, 3, 1, 1, 2)`` — presumably
            ``(batch, 3, S, S, 5 + num_classes)``; TODO confirm with caller.
        target: Ground truth laid out like ``preds`` but with 6 entries on
            the last axis. Cells whose objectness flag is neither 0 nor 1
            (e.g. -1) contribute to no loss term.
        anchors: Tensor holding exactly 3 anchor (w, h) pairs; it is
            reshaped locally to ``(1, 3, 1, 1, 2)``.

    Returns:
        Scalar tensor ``10 * box + 1 * object + 10 * no_object + 1 * class``.
        NOTE(review): every term uses mean reduction, so a mask that selects
        no cells likely yields NaN — confirm callers never hit that case.
    """
    mse = nn.MSELoss()
    bcewll = nn.BCEWithLogitsLoss()
    cross_entropy = nn.CrossEntropyLoss()

    # Relative weights of the individual loss terms.
    l_class = 1
    l_nj = 10
    l_box = 10
    l_obj = 1

    # Check where obj and nj (cells with any other flag, e.g. -1, are ignored)
    obj = target[..., 0] == 1  # in paper this is Iobj_i
    nj = target[..., 0] == 0  # in paper this is Inoobj_i

    # ======================= #
    #   FOR NO OBJECT LOSS    #
    # ======================= #

    no_object_loss = bcewll(preds[..., 0:1][nj], target[..., 0:1][nj])

    # ==================== #
    #   FOR OBJECT LOSS    #
    # ==================== #

    # Rebinding the local name only; the caller's anchors tensor is untouched.
    anchors = anchors.reshape(1, 3, 1, 1, 2)

    # Computed once and shared by the IoU decoding and the coordinate loss.
    xy_pred = torch.sigmoid(preds[..., 1:3])
    wh_raw = preds[..., 3:5]

    box_preds = torch.cat([xy_pred, torch.exp(wh_raw) * anchors], dim=-1)
    # IoU only scales the objectness target, so no gradient flows through it.
    ious = intersection_over_union(box_preds[obj], target[..., 1:5][obj]).detach()

    loss_obj = mse(torch.sigmoid(preds[..., 0:1][obj]), ious * target[..., 0:1][obj])

    # ======================== #
    #   FOR BOX COORDINATES    #
    # ======================== #

    # Predictions: sigmoid-ed x,y next to the raw w,h outputs.
    encoded_preds = torch.cat([xy_pred, wh_raw], dim=-1)
    # Targets: x,y as given, w,h log-encoded relative to the anchors so they
    # live in the same space as the raw w,h outputs. Built as new tensors
    # instead of writing back into ``target``.
    encoded_target = torch.cat(
        [target[..., 1:3], torch.log(1e-16 + target[..., 3:5] / anchors)],
        dim=-1,
    )
    box_loss = mse(encoded_preds[obj], encoded_target[obj])

    # ================== #
    #   FOR CLASS LOSS   #
    # ================== #

    class_loss = cross_entropy(preds[..., 5:][obj], target[..., 5][obj].long())

    return (
        l_box * box_loss
        + l_obj * loss_obj
        + l_nj * no_object_loss
        + l_class * class_loss
    )

In [52]:
# Scale every anchor (w, h) pair by the grid size of its prediction scale.
# config.S is lifted to shape (num_scales, 1, 1) and then tiled to
# (num_scales, 3, 2) so it lines up element-wise with config.ANCHORS.
scaled_anchors = torch.mul(
    torch.tensor(config.ANCHORS),
    torch.tensor(config.S)[:, None, None].repeat(1, 3, 2),
)
scaled_anchors

tensor([[[ 3.6400,  2.8600],
         [ 4.9400,  6.2400],
         [11.7000, 10.1400]],

        [[ 1.8200,  3.9000],
         [ 3.9000,  2.8600],
         [ 3.6400,  7.5400]],

        [[ 1.0400,  1.5600],
         [ 2.0800,  3.6400],
         [ 4.1600,  3.1200]]])

In [58]:
# Small integer tensor of shape (2, 3, 3) holding 1..18 in row-major order,
# used below to illustrate Ellipsis indexing.
rnd_t = torch.arange(1, 19).reshape(2, 3, 3)

In [60]:
# `...` stands for all leading axes, so this takes element 0 along the last
# axis only — the same indexing pattern as `target[..., 0]` in the loss.
rnd_t[..., 0]

tensor([[ 1,  4,  7],
        [10, 13, 16]])