In [3]:
import numpy as np
import torch.nn as nn
import torch
from utils import bbox_iou

In [None]:
class Yolo_Loss(nn.Module):
    """YOLO-style detection loss for one prediction scale.

    The loss is a weighted sum of four terms: box-coordinate regression,
    objectness for cells that contain an object, objectness for cells that
    do not, and classification.

    NOTE(review): the channel layout used below is an assumption that must be
    confirmed against the dataset/model code, which is not visible here:

    * ``prediction[..., 0]``   objectness logit
    * ``prediction[..., 1:3]`` raw x, y offsets (sigmoid is applied here)
    * ``prediction[..., 3:5]`` raw log-space width, height
    * ``prediction[..., 5:]``  class logits
    * ``target[..., 0]``       objectness (1 = object, 0 = no object; any
      other value is excluded from both masks)
    * ``target[..., 1:5]``     x, y, width, height
    * ``target[..., 5]``       class index
    """

    def __init__(self):
        super().__init__()
        # losses and functions
        self.bcwell = nn.BCEWithLogitsLoss()
        self.mse = nn.MSELoss()
        self.cross_entropy = nn.CrossEntropyLoss()
        self.sigmoid = nn.Sigmoid()

        # lambda constants (relative weight of each loss term)
        self.lambda_class = 1
        self.lambda_noobj = 10
        self.lambda_box = 10
        self.lambda_obj = 1

    def forward(self, prediction, target, anchors):
        """Return the weighted scalar loss for one scale.

        Neither ``prediction`` nor ``target`` is modified; all intermediate
        encodings are built out of place.

        Args:
            prediction: raw network output; presumably shaped
                ``(batch, num_anchors, S, S, 5 + num_classes)`` -- confirm.
            target: ground truth; presumably shaped
                ``(batch, num_anchors, S, S, 6)`` -- confirm.
            anchors: anchor widths/heights for this scale, reshapeable to
                ``(num_anchors, 2)``.

        Returns:
            A zero-dimensional tensor holding the combined loss.
        """
        # check objectness for identity function
        obj = target[..., 0] == 1  # I^obj_ij
        noobj = target[..., 0] == 0  # I^noobj_ij

        # Reshape before any use so the anchors broadcast over batch and grid.
        anchors = anchors.reshape(1, -1, 1, 1, 2)

        # A term whose mask selects nothing contributes zero instead of the
        # NaN that a mean over an empty selection would produce.
        zero = prediction.new_zeros(())

        ## no object loss
        if noobj.any():
            noobj_loss = self.bcwell(
                prediction[..., 0:1][noobj], target[..., 0:1][noobj]
            )
        else:
            noobj_loss = zero

        if obj.any():
            xy_pred = self.sigmoid(prediction[..., 1:3])
            wh_pred = prediction[..., 3:5]

            ## box coordinate loss
            # x, y are compared after the sigmoid; width/height are compared
            # in log space relative to the anchor, matching the raw outputs.
            wh_target = torch.log(1e-16 + target[..., 3:5] / anchors)
            box_pred = torch.cat([xy_pred, wh_pred], dim=-1)
            box_target = torch.cat([target[..., 1:3], wh_target], dim=-1)
            bbox_coord_loss = self.mse(box_pred[obj], box_target[obj])

            ## object loss
            # Decode the predicted boxes and score them against the
            # *unencoded* target boxes.
            # NOTE(review): bbox_iou comes from the project's utils module;
            # it is assumed to take two (n, 4) tensors of x, y, w, h boxes and
            # return one IoU per row -- confirm.
            box_decoded = torch.cat(
                [xy_pred, torch.exp(wh_pred) * anchors], dim=-1
            )
            result = bbox_iou(box_decoded[obj], target[..., 1:5][obj]).detach()
            # Force a column so the product below cannot broadcast to (n, n).
            result = result.reshape(-1, 1)
            obj_loss = self.mse(
                self.sigmoid(prediction[..., 0:1][obj]),
                result * target[..., 0:1][obj],
            )

            ## class loss
            class_loss = self.cross_entropy(
                prediction[..., 5:][obj], target[..., 5][obj].long()
            )
        else:
            bbox_coord_loss = obj_loss = class_loss = zero

        loss = (
            self.lambda_box * bbox_coord_loss
            + self.lambda_obj * obj_loss
            + self.lambda_noobj * noobj_loss
            + self.lambda_class * class_loss
        )

        return loss