In [75]:
import torch.nn as nn
import torch
import config
import numpy as np
from model import intersection_over_union

In [129]:
def bbox_iou(box1: torch.FloatTensor, box2: torch.FloatTensor):
    """Returns intersection over union of two bounding boxes.

    Strictly performed on tensors.

    Args:
        box1 (torch.FloatTensor): Coordinates of bbox 1.
        box2 (torch.FloatTensor): Coordinates of bbox 2.

    Returns:
        iou (float): IOU of two input bboxes.
    """
    # get coords of bboxes
    b1_x1, b1_y1, b1_x2, b1_y2 = box1[:,0], box1[:,1], box1[:,2], box1[:,3]
    b2_x1, b2_y1, b2_x2, b2_y2 = box2[:,0], box2[:,1], box2[:,2], box2[:,3]

    # get coords of intersection
    intersect_x1 = torch.max(b1_x1, b2_x1)
    intersect_y1 = torch.max(b1_y1, b2_y1)
    intersect_x2 = torch.min(b1_x2, b2_x2)
    intersect_y2 = torch.min(b1_y2, b2_y2)

    # intersection area
    # clamp to > 0
    # this avoids areas being calculated for boxes with zero intersect
    intersect_area = torch.clamp(intersect_x2 - intersect_x1, min=0)*torch.clamp(intersect_y2 - intersect_y1, min=0)

    # union area
    b1_area = (b1_x2 - b1_x1)*(b1_y2 - b1_y1)
    b2_area = (b2_x2 - b2_x1)*(b2_y2 - b2_y1)
    union_area = b1_area + b2_area - intersect_area

    # compute iou
    iou = intersect_area/union_area

    return iou

In [130]:
def centre_dims_to_corners(bbox: np.ndarray) -> np.ndarray:
    """Converts bbox attributes of form [x_centre, y_centre, width, height] to form [x1, y1, x2, y2]. 
    
    Use on an array of bboxes: [[bbox_1], [bbox_2], ... [bbox_n]].

    This form is used for easily calculating 2 bbox's IoU.

    Args:
        bbox (np.ndarray): Bbox centre and dims [x_centre, y_centre, width, height].

    Returns:
        new_bbox (np.ndarray): Bbox corner coords [x1, y1, x2, y2].
    """
    if len(bbox.shape) > 1:
        x_c, y_c, w, h = bbox[:,0], bbox[:,1], bbox[:,2], bbox[:,3]
    else:
        x_c, y_c, w, h = bbox[0], bbox[1], bbox[2], bbox[3]
    x1, x2 = x_c-(w/2), x_c+(w/2)
    y1, y2 = y_c-(h/2), y_c+(h/2)
    
    if len(bbox.shape) > 1:
        x1 = np.expand_dims(x1, 1)
        x2 = np.expand_dims(x2, 1)
        y1 = np.expand_dims(y1, 1)
        y2 = np.expand_dims(y2, 1)

        new_bbox = np.concatenate((x1, y1, x2, y2), axis=1)
    else:
        new_bbox = np.array([x1, y1, x2, y2])

    return new_bbox

# From Scratch 1 with fake data

In [54]:
# losses
bcewll = nn.BCEWithLogitsLoss()
bce = nn.BCELoss()
mse = nn.MSELoss()
ce = nn.CrossEntropyLoss()
sigmoid_function = nn.Sigmoid()

In [116]:
target = torch.tensor([
    [1, 25, 25, 50, 50, 0, 0, 1], 
    [0, 0, 0, 0, 0, 0, 0, 0],
    [1, 50, 50, 50, 50, 0, 1, 0],
    [0, 0, 0, 0, 0, 0, 0, 0]]) # 2x2 gridsize [obj, xc, yc, w, h, cls...]

good_pred = torch.tensor([
    [0.9, 24, 26.5, 48, 47, 0.01, 0.01, 0.98], 
    [0.2, 0, 0, 0, 0, 0, 0, 0],
    [0.97, 51, 52, 49, 51, 0.02, 0.98, 0],
    [0.1, 0, 0, 0, 0, 0, 0, 0]])

bad_pred = torch.tensor([
    [0.2, 2, 265, 8, 27, 0.3, 0.5, 0.2], 
    [0.9, 0, 0, 0, 0, 0, 0, 0],
    [0.1, 50, 2, 492, 81, 0.7, 0.2, 0.1],
    [0.97, 0, 0, 0, 0, 0, 0, 0]]) 

obj = (target[...,0] == 1)
nj = (target[...,0] == 0)

### No Obj Loss

In [42]:
good_no_object_loss_wll = bcewll(
    (good_pred[..., 0:1][nj].float()), (target[..., 0:1][nj].float()),
)

bad_no_object_loss_wll = bcewll(
    (bad_pred[..., 0:1][nj].float()), (target[..., 0:1][nj].float()),
)

good_no_object_loss = bce(
    (good_pred[..., 0:1][nj].float()), (target[..., 0:1][nj].float()),
)

bad_no_object_loss = bce(
    (bad_pred[..., 0:1][nj].float()), (target[..., 0:1][nj].float()),
)


print(f"good bcewll:    {good_no_object_loss_wll}")
print(f"bad bcewll:     {bad_no_object_loss_wll}")
print(f"good bce:       {good_no_object_loss}")
print(f"bad bce:        {bad_no_object_loss}")


good bcewll:    0.7712677717208862
bad bcewll:     1.2662863731384277
good bce:       0.16425204277038574
bad bce:        2.904572010040283


### Obj Loss

- already applied sigmoid to prediction x_c, y_c
- already applied exp to prediction w, h
- already scaled prediction w, h by anchors

In [70]:
# preds already scaled by anchors
scaled_anchors = (
        torch.tensor(config.ANCHORS)
        * torch.tensor(config.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
    )
anchors = scaled_anchors[0].reshape(1, 3, 1, 1, 2) # 3 anchors go in at a time

In [55]:
sigmoid_function(good_pred[..., 1:3])

tensor([[1.0000, 1.0000],
        [0.5000, 0.5000],
        [1.0000, 1.0000],
        [0.5000, 0.5000]])

In [137]:
target_bbox = torch.tensor([centre_dims_to_corners(t) for t in target[...,1:5][obj]])
good_pred_bbox = torch.tensor([centre_dims_to_corners(t) for t in good_pred[...,1:5][obj]])
bad_pred_bbox = torch.tensor([centre_dims_to_corners(t) for t in bad_pred[...,1:5][obj]])

In [153]:
result = bbox_iou(target_bbox, good_pred_bbox)
result

tensor([0.9024, 0.8887])

In [152]:
target[..., 0][obj]*result

tensor([0.9024, 0.8887])

In [155]:
good_pred[..., 0:1][obj].T

tensor([[0.9000, 0.9700]])

In [157]:
mse(good_pred[..., 0:1][obj], result*target[..., 0:1][obj])

  return F.mse_loss(input, target, reduction=self.reduction)


tensor(0.0028)

In [81]:
box_preds[obj]

tensor([[1.0000e+00, 1.0000e+00, 7.0167e+20, 2.5813e+20],
        [1.0000e+00, 1.0000e+00, 1.9073e+21, 1.4093e+22]])

In [82]:
target[..., 1:5][obj]

tensor([[25, 25, 50, 50],
        [50, 50, 50, 50]])

In [128]:

good_object_loss = bce(
    (good_pred[..., 0:1][obj].float()), (target[..., 0:1][obj].float()),
)

bad_object_loss = bce(
    (bad_pred[..., 0:1][obj].float()), (target[..., 0:1][obj].float()),
)


print(f"good bcewll:    {good_object_loss_wll}")
print(f"bad bcewll:     {bad_object_loss_wll}")
print(f"good bce:       {good_object_loss}")
print(f"bad bce:        {bad_object_loss}")


good bcewll:    0.33128637075424194
bad bcewll:     0.6212677955627441
good bce:       0.0679098591208458
bad bce:        1.9560115337371826


### Bbox Loss

- already applied sigmoid to prediction x_c, y_c
- already applied exp and anchor scale to prediction w, h
    -> so dont apply reverse operations on target/labels

In [114]:
good_pred[..., 1:5][obj]

tensor([[24.0000, 26.5000, 48.0000, 47.0000],
        [51.0000, 52.0000, 49.0000, 51.0000]])

In [119]:
torch.sqrt(good_pred[..., 1:5][obj])

tensor([[4.8990, 5.1478, 6.9282, 6.8557],
        [7.1414, 7.2111, 7.0000, 7.1414]])

In [115]:
target[..., 1:5][obj]

tensor([[25, 25, 50, 50],
        [50, 50, 50, 50]])

In [120]:
torch.sqrt(target[..., 1:5][obj])

tensor([[5.0000, 5.0000, 7.0711, 7.0711],
        [7.0711, 7.0711, 7.0711, 7.0711]])

In [127]:
good_box_loss = mse(good_pred[..., 1:5][obj], target[..., 1:5][obj])
bad_box_loss = mse(bad_pred[..., 1:5][obj], target[..., 1:5][obj])

good_sqrt_wh_box_loss = mse(good_pred[..., 1:3][obj], target[..., 1:3][obj]) + mse(torch.sqrt(good_pred[..., 3:5][obj]), torch.sqrt(target[..., 3:5][obj]))
bad_sqrt_wh_box_loss = mse(bad_pred[..., 1:3][obj], target[..., 1:3][obj]) + mse(torch.sqrt(bad_pred[..., 3:5][obj]), torch.sqrt(target[..., 3:5][obj]))

good_sqrt_box_loss = mse(torch.sqrt(good_pred[..., 1:5][obj]), torch.sqrt(target[..., 1:5][obj]))
bad_sqrt_box_loss = mse(torch.sqrt(bad_pred[..., 1:5][obj]), torch.sqrt(target[..., 1:5][obj]))

print(f"good box:            {good_box_loss}")
print(f"bad box:             {bad_box_loss}")
print(f"good sqrt(w,h) box:  {good_sqrt_wh_box_loss}")
print(f"bad sqrt(w,h) box:   {bad_sqrt_wh_box_loss}")
print(f"good sqrt(all) box:  {good_sqrt_box_loss}")
print(f"bad sqrt(all) box:   {bad_sqrt_box_loss}")

good box:            2.90625
bad box:             32381.375
good sqrt(w,h) box:  2.0817036628723145
bad sqrt(w,h) box:   15171.63671875
good sqrt(all) box:  0.016678646206855774
bad sqrt(all) box:   53.202247619628906


### Class Loss

In [102]:
cls_argmax = torch.argmax(target[..., 5:], dim=-1)

In [92]:
good_pred[..., 5:][obj]

tensor([[0.0100, 0.0100, 0.9800],
        [0.0200, 0.9800, 0.0000]])

In [96]:
bad_pred[..., 5:][obj]

tensor([[0.0100, 0.0100, 0.9800],
        [0.0200, 0.9800, 0.0000]])

In [106]:
good_class_loss = ce((good_pred[..., 5:][obj]), (cls_argmax[obj]))
bad_class_loss = ce((bad_pred[..., 5:][obj]), (cls_argmax[obj]))

print(f"good class: {good_class_loss}")
print(f"bad class:  {bad_class_loss}")

good class: 0.5642820596694946
bad class:  1.2538902759552002
