In [1]:
# Codeblock 1
import torch
import torch.nn as nn

In [2]:
# Codeblock 2
def intersection_over_union(boxes_targets, boxes_predictions):
    """Return the IoU between target and predicted boxes.

    Both arguments hold boxes in their last dimension as
    ``(x_center, y_center, width, height)``; only channels 0..3 are read.
    The result keeps the leading dimensions and has a trailing size of 1.
    """

    def _to_corners(boxes):
        # Midpoint format -> (left, top, right, bottom); the 0:1-style slices
        # keep the last dimension so the outputs broadcast as column tensors.
        center_x, center_y = boxes[..., 0:1], boxes[..., 1:2]
        half_w, half_h = boxes[..., 2:3] / 2, boxes[..., 3:4] / 2
        return (center_x - half_w, center_y - half_h,
                center_x + half_w, center_y + half_h)

    tgt_left, tgt_top, tgt_right, tgt_bottom = _to_corners(boxes_targets)
    prd_left, prd_top, prd_right, prd_bottom = _to_corners(boxes_predictions)

    # Overlap rectangle; a negative extent means the boxes are disjoint along
    # that axis, so it is clamped to zero.
    overlap_w = (torch.min(prd_right, tgt_right)
                 - torch.max(prd_left, tgt_left)).clamp(0)
    overlap_h = (torch.min(prd_bottom, tgt_bottom)
                 - torch.max(prd_top, tgt_top)).clamp(0)
    intersection = overlap_w * overlap_h

    prd_area = torch.abs((prd_right - prd_left) * (prd_bottom - prd_top))
    tgt_area = torch.abs((tgt_right - tgt_left) * (tgt_bottom - tgt_top))

    # The small constant keeps the division defined when both boxes are empty.
    union = prd_area + tgt_area - intersection + 1e-6

    return intersection / union

In [3]:
# Codeblock 3
# The same 200x200 target centred at the origin is used in all three cases.
target_0 = torch.tensor([[0., 0., 200., 200.]])
target_1 = torch.tensor([[0., 0., 200., 200.]])
target_2 = torch.tensor([[0., 0., 200., 200.]])

# Predictions of identical size whose centre moves further from the target.
pred_0 = torch.tensor([[20., 20., 200., 200.]])
pred_1 = torch.tensor([[100., 100., 200., 200.]])
pred_2 = torch.tensor([[180., 180., 200., 200.]])

# The IoU shrinks as the overlap between the two boxes shrinks.
iou_0 = intersection_over_union(target_0, pred_0)
iou_1 = intersection_over_union(target_1, pred_1)
iou_2 = intersection_over_union(target_2, pred_2)

print('iou_0:', iou_0)

print('iou_1:', iou_1)

print('iou_2:', iou_2)

iou_0: tensor([[0.6807]])
iou_1: tensor([[0.1429]])
iou_2: tensor([[0.0050]])


In [4]:
# Codeblock 4
# Layout constants read by `loss` and the test cells:
#   S -> the target/prediction tensors are reshaped to an S x S grid,
#   B -> each prediction cell carries B * 5 box channels,
#   C -> number of leading per-cell channels before the box channels.
S, B, C = 7, 2, 20

# Leading dimension of the tensors built by the test cells.
BATCH_SIZE = 1

# Weights applied to the bounding-box and no-object terms of the total loss.
lambda_coord, lambda_noobj = 5, 0.5

# Sum-of-squared-errors criterion shared by every loss term.
sse = nn.MSELoss(reduction="sum")

In [5]:
# Codeblock 5a
def loss(target, prediction):
    """Compute the box, object, no-object and class loss terms plus their total.

    Args:
        target: tensor reshapeable to ``(-1, S, S, C + 5)``. Per cell the
            channels are ``[0:C]`` class scores, ``[C]`` object confidence and
            ``[C+1:C+5]`` the box as ``(x, y, w, h)``.
        prediction: tensor reshapeable to ``(-1, S, S, C + B * 5)``. Per cell
            the channels are ``[0:C]`` class scores, ``[C]`` / ``[C+1:C+5]``
            confidence and box of the first predictor and ``[C+5]`` /
            ``[C+6:C+10]`` confidence and box of the second predictor.
            Only these two predictors are read, whatever the value of ``B``.

    Returns:
        Tuple ``(bbox_loss, object_loss, no_object_loss, class_loss,
        total_loss)`` of scalar tensors, where ``total_loss`` weights the box
        term by ``lambda_coord`` and the no-object term by ``lambda_noobj``.
    """
    # Codeblock 5a: restore the grid layout and build the cell masks.
    target = target.reshape(-1, S, S, C + 5)
    prediction = prediction.reshape(-1, S, S, C + B * 5)

    obj = target[..., C:C + 1]    # target confidence, kept as a size-1 channel
    noobj = 1 - obj

    # Codeblock 5b: bounding-box loss.
    target_bbox = target[..., C + 1:C + 5]

    pred_bbox0 = prediction[..., C + 1:C + 5]
    pred_bbox1 = prediction[..., C + 6:C + 10]

    # Arguments follow the (targets, predictions) order of the signature.
    iou_pred_bbox0 = intersection_over_union(target_bbox, pred_bbox0)
    iou_pred_bbox1 = intersection_over_union(target_bbox, pred_bbox1)

    # Per cell: the larger IoU and the index (0 or 1) of the predictor that
    # achieved it.
    best_iou, best_bbox_idx = torch.max(
        torch.stack([iou_pred_bbox0, iou_pred_bbox1], dim=0), dim=0
    )

    # Zero out cells without an object and keep only the best predictor's box.
    target_bbox = obj * target_bbox
    best_bbox = obj * (best_bbox_idx * pred_bbox1
                       + (1 - best_bbox_idx) * pred_bbox0)

    # Square-root the width/height channels. New tensors are assembled with
    # torch.cat instead of assigning into a slice: an in-place write into
    # best_bbox would invalidate the view that torch.abs saves for its
    # backward pass and make total_loss.backward() raise a RuntimeError.
    target_wh = torch.sqrt(target_bbox[..., 2:4])
    best_wh = best_bbox[..., 2:4]
    # sign(...) * sqrt(|...| + eps) keeps the sign of a negative prediction and
    # avoids the unbounded slope of sqrt at exactly zero.
    best_wh = torch.sign(best_wh) * torch.sqrt(torch.abs(best_wh) + 1e-6)

    target_bbox = torch.cat([target_bbox[..., 0:2], target_wh], dim=-1)
    best_bbox = torch.cat([best_bbox[..., 0:2], best_wh], dim=-1)

    bbox_loss = sse(
        torch.flatten(best_bbox, end_dim=-2),
        torch.flatten(target_bbox, end_dim=-2),
    )

    # Codeblock 5c: confidence loss for cells that contain an object.
    target_bbox_confidence = target[..., C:C + 1]
    pred_bbox0_confidence = prediction[..., C:C + 1]
    pred_bbox1_confidence = prediction[..., C + 5:C + 6]

    target_bbox_confidence = obj * target_bbox_confidence
    best_bbox_confidence = obj * (best_bbox_idx * pred_bbox1_confidence
                                  + (1 - best_bbox_idx) * pred_bbox0_confidence)

    # The confidence target is scaled by the IoU of the chosen predictor.
    # NOTE(review): best_iou is computed from `prediction`, so this term also
    # sends gradient into the box coordinates; detach it if the IoU is meant
    # to act as a constant target.
    object_loss = sse(
        torch.flatten(obj * best_bbox_confidence),
        torch.flatten(obj * target_bbox_confidence * best_iou),
    )

    # Codeblock 5d: confidence loss for cells without an object; both
    # predictors are penalised.
    no_object_loss = sse(
        torch.flatten(noobj * pred_bbox0_confidence),
        torch.flatten(noobj * target_bbox_confidence),
    ) + sse(
        torch.flatten(noobj * pred_bbox1_confidence),
        torch.flatten(noobj * target_bbox_confidence),
    )

    # Codeblock 5e: classification loss on cells that contain an object.
    target_class = target[..., :C]
    pred_class = prediction[..., :C]

    class_loss = sse(
        torch.flatten(obj * pred_class, end_dim=-2),
        torch.flatten(obj * target_class, end_dim=-2),
    )

    # Codeblock 5f: weighted sum of the four terms.
    total_loss = (
        lambda_coord * bbox_loss
        + object_loss
        + lambda_noobj * no_object_loss
        + class_loss
    )

    return bbox_loss, object_loss, no_object_loss, class_loss, total_loss

In [6]:
# Codeblock 6
def bbox_loss_test(pred_box=(0.4, 0.5, 2.4, 3.2)):
    """Return the bounding-box term of ``loss`` for one hand-built object cell.

    Cell (3, 3) of the first sample holds an object of class 7 with target box
    ``(0.4, 0.5, 2.4, 3.2)``. The first predictor of the same cell is given
    ``pred_box``; everything else is zero.

    Args:
        pred_box: ``(x, y, w, h)`` written into the first predicted box. The
            default equals the target box. Try ``(0.4, 0.5, 2.8, 4.0)`` or
            ``(0.3, 0.2, 3.2, 4.3)`` to see the loss grow as the prediction
            drifts away from the target.

    Returns:
        The first element of the tuple returned by ``loss``.
    """
    target = torch.zeros(BATCH_SIZE, S, S, C + 5)
    prediction = torch.zeros(BATCH_SIZE, S, S, C + B * 5)

    # Ground truth: box, confidence flag and one-hot class of the object cell.
    target[0, 3, 3, C + 1:C + 5] = torch.tensor([0.4, 0.5, 2.4, 3.2])
    target[0, 3, 3, C] = 1.0
    target[0, 3, 3, 7] = 1.0

    prediction[0, 3, 3, C + 1:C + 5] = torch.tensor(pred_box)

    # Flatten each sample to one vector; `loss` reshapes back to the grid.
    target = target.reshape(BATCH_SIZE, S * S * (C + 5))
    prediction = prediction.reshape(BATCH_SIZE, S * S * (C + B * 5))

    return loss(target, prediction)[0]

bbox_loss_test()

tensor(1.8474e-13)

In [7]:
# Codeblock 7
def object_loss_test(pred_confidence=1.0):
    """Return the object-confidence term of ``loss`` for one object cell.

    Cell (3, 3) of the first sample holds an object of class 7 whose box is
    predicted exactly by the first predictor; only the predicted confidence of
    that predictor varies.

    Args:
        pred_confidence: confidence written for the first predictor of the
            object cell. The default is 1.0; try 0.9 or 0.6 to see the loss
            grow as the confidence drops.

    Returns:
        The second element of the tuple returned by ``loss``.
    """
    target = torch.zeros(BATCH_SIZE, S, S, C + 5)
    prediction = torch.zeros(BATCH_SIZE, S, S, C + B * 5)

    # Ground truth: box, confidence flag and one-hot class of the object cell.
    target[0, 3, 3, C + 1:C + 5] = torch.tensor([0.4, 0.5, 2.4, 3.2])
    target[0, 3, 3, C] = 1.0
    target[0, 3, 3, 7] = 1.0

    # The predicted box matches the target; only the confidence is varied.
    prediction[0, 3, 3, C + 1:C + 5] = torch.tensor([0.4, 0.5, 2.4, 3.2])
    prediction[0, 3, 3, C] = pred_confidence

    # Flatten each sample to one vector; `loss` reshapes back to the grid.
    target = target.reshape(BATCH_SIZE, S * S * (C + 5))
    prediction = prediction.reshape(BATCH_SIZE, S * S * (C + B * 5))

    return loss(target, prediction)[1]

object_loss_test()

tensor(1.4211e-14)

In [8]:
# Codeblock 8
def class_loss_test(pred_class_scores=(1.0, 0.0)):
    """Return the classification term of ``loss`` for one object cell.

    Cell (3, 3) of the first sample holds an object of class 7 whose box is
    predicted exactly by the first predictor; only the predicted class scores
    vary.

    Args:
        pred_class_scores: pair of scores written to class channels 7 and 8 of
            the object cell. The default ``(1.0, 0.0)`` matches the target;
            try ``(0.9, 0.1)`` or ``(0.2, 0.8)`` to see the loss grow.

    Returns:
        The fourth element of the tuple returned by ``loss``.
    """
    target = torch.zeros(BATCH_SIZE, S, S, C + 5)
    prediction = torch.zeros(BATCH_SIZE, S, S, C + B * 5)

    # Ground truth: box, confidence flag and one-hot class of the object cell.
    target[0, 3, 3, C + 1:C + 5] = torch.tensor([0.4, 0.5, 2.4, 3.2])
    target[0, 3, 3, C] = 1.0
    target[0, 3, 3, 7] = 1.0

    # The predicted box matches the target; only the class scores are varied.
    prediction[0, 3, 3, C + 1:C + 5] = torch.tensor([0.4, 0.5, 2.4, 3.2])
    prediction[0, 3, 3, 7:9] = torch.tensor(pred_class_scores)

    # Flatten each sample to one vector; `loss` reshapes back to the grid.
    target = target.reshape(BATCH_SIZE, S * S * (C + 5))
    prediction = prediction.reshape(BATCH_SIZE, S * S * (C + B * 5))

    return loss(target, prediction)[3]

class_loss_test()

tensor(0.)

In [9]:
# Codeblock 9
def no_object_loss_test(pred_confidences=(0.0, 0.0)):
    """Return the no-object confidence term of ``loss`` for an empty grid.

    No cell contains an object. Cell (1, 1) of the first sample receives the
    given confidences for its two predictors; everything else is zero.

    Args:
        pred_confidences: pair ``(first, second)`` of confidences written for
            the two predictors of cell (1, 1). The default ``(0.0, 0.0)``
            matches the target; try ``(0.2, 0.3)`` to see the loss grow.

    Returns:
        The third element of the tuple returned by ``loss``.
    """
    target = torch.zeros(BATCH_SIZE, S, S, C + 5)
    prediction = torch.zeros(BATCH_SIZE, S, S, C + B * 5)

    # Explicitly mark the inspected cell as empty in the ground truth.
    target[0, 1, 1, C] = 0.0

    first_confidence, second_confidence = pred_confidences
    prediction[0, 1, 1, C] = first_confidence
    prediction[0, 1, 1, C + 5] = second_confidence

    # Flatten each sample to one vector; `loss` reshapes back to the grid.
    target = target.reshape(BATCH_SIZE, S * S * (C + 5))
    prediction = prediction.reshape(BATCH_SIZE, S * S * (C + B * 5))

    return loss(target, prediction)[2]

no_object_loss_test()

tensor(0.)