In [None]:
import sys
sys.path.insert(0, "..")
from ipdb import set_trace

from data import DATA_HOME

In [None]:
from yolo.models import YOLOv1
from yolo.data.VOC_Dataset import VOC_Dataset
from utils.img_process import show_bbox
from torch.utils.data import DataLoader
from numpy import array
import torch

# Model output per image has 7 * 7 * 30 values:
# S * S * ((x, y, w, h, confidence) * B=2 + C=20)
_voc_root = "{}/VOCdevkit/VOC2007/".format(DATA_HOME)
voc_ds = VOC_Dataset(_voc_root)

def collate_fn(data):
    """Batch (image, boxes, classes) samples for the DataLoader.

    Images are stacked into a single float tensor. Boxes and classes are
    returned as tuples with one entry per image, because each image can hold
    a different number of objects (an inhomogeneous shape that cannot be
    stacked): boxes are (batch size, # of objects in each image, 4 coords).
    """
    images, boxes, class_ids = zip(*data)
    stacked_images = torch.tensor(array(images), dtype=torch.float)
    return stacked_images, boxes, class_ids
    
# Pinned host memory only helps host->GPU copies. On a CPU-only machine the
# request is ignored (and newer PyTorch versions warn about it), so only ask
# for it when CUDA is actually available.
loader = DataLoader(
    voc_ds,
    batch_size=4,
    shuffle=True,
    num_workers=4,
    pin_memory=torch.cuda.is_available(),
    collate_fn=collate_fn,
)


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# dummy single-image input of shape (1, 3, 448, 448); drawn on the CPU generator, then moved
test = torch.randn(3, 448, 448)[None].to(device=device)
yolo = YOLOv1().to(device=device)

# grid rows/cols, number of classes, bounding boxes predicted per cell
S, C, B = 7, 20, 2

# weights for the coordinate and no-object loss terms (original spelling kept)
lamba_coord, lamba_noobj = 5, 0.5


def IOU(xywh, xyxy):
    """Return the intersection over union of two axis-aligned rectangles.

    Keyword arguments
    xywh -- first rect (e.g. a model output) as centre and size
            (x, y, w, h), in pct (image-relative) coordinates
    xyxy -- second rect (e.g. a label) as corners
            (xmin, ymin, xmax, ymax), in pct coordinates

    Returns 0.0 when the overlap has no positive width or height.
    """
    half_w = xywh[2] / 2
    half_h = xywh[3] / 2
    left, right = xywh[0] - half_w, xywh[0] + half_w
    top, bottom = xywh[1] - half_h, xywh[1] + half_h

    overlap_w = min(right, xyxy[2]) - max(left, xyxy[0])
    overlap_h = min(bottom, xyxy[3]) - max(top, xyxy[1])
    if overlap_w <= 0.0 or overlap_h <= 0.0:
        return 0.0

    inter = overlap_w * overlap_h
    area_first = (right - left) * (bottom - top)
    area_second = (xyxy[2] - xyxy[0]) * (xyxy[3] - xyxy[1])
    return inter / (area_first + area_second - inter)

# sanity checks of IOU
# coords is a centre/size (x, y, w, h) rect spanning (0, 0)-(0.5, 0.5);
# each y_coords* is a corner (xmin, ymin, xmax, ymax) label rect.
coords = (0.25, 0.25, 0.5, 0.5)
(
    y_coords1,  # partial overlap
    y_coords2,  # partial overlap, mirrored across the diagonal
    y_coords3,  # identical to coords
    y_coords4,  # zero-height rect
    y_coords5,  # only touches coords' right edge
    y_coords6,  # fully inside coords
) = (
    (0.25, 0.4, 0.75, 0.75),
    (0.4, 0.25, 0.75, 0.75),
    (0, 0, 0.5, 0.5),
    (0, 0, 0.5, 0),
    (0.5, 0, 0.75, 0.75),
    (0.1, 0.1, 0.3, 0.4),
)

def float_eqs(a, b, decimal_pt):
    """Return whether a and b differ by strictly less than 10 ** -decimal_pt."""
    difference = abs(a - b)
    return difference < 10 ** (-decimal_pt)

# (label rect, expected IOU against coords), checked to 5 decimal places
_iou_checks = [
    (y_coords1, 0.025 / (0.5*0.5 + 0.5*0.35 - 0.025)),
    (y_coords2, 0.025 / (0.5*0.5 + 0.5*0.35 - 0.025)),
    (y_coords3, 1),
    (y_coords4, 0),
    (y_coords5, 0),
    (y_coords6, 0.06 / (0.5*0.5)),
]
for _label_rect, _expected_iou in _iou_checks:
    assert float_eqs(IOU(coords, _label_rect), _expected_iou, 5)
# keep the notebook namespace free of loop temporaries
del _iou_checks, _label_rect, _expected_iou

In [None]:
from collections import defaultdict as dd

def yolo_cell_loss(i, j, yhat, bbox_y, class_y):
    """Placeholder for a per-cell loss helper; not implemented yet.

    Ignores all of its arguments and returns None.
    """
    return None

def yolo_loss(res_mat: torch.Tensor, label_mat: list, class_mat: list):
    """Calculate the YOLOv1 sum-squared-error loss, summed over a batch.

    Keyword arguments
    res_mat   -- network output of shape (batch_size, B*5+C, S, S); for each
                 grid cell the channels are B blocks of
                 (x, y, w, h, confidence) followed by C class scores
    label_mat -- per image, a sequence of ground-truth boxes in pct
                 coordinates (xmin, ymin, xmax, ymax)
    class_mat -- per image, the class of each box in label_mat
                 (assumed to be an integer index in [0, C) -- confirm)

    Returns a 0-d tensor on res_mat's device.

    Per grid cell:
      * cell containing a label centre: the predictor with the highest IOU
        against the label is "responsible". It gets lamba_coord-weighted
        centre and sqrt-size errors and a confidence target equal to that
        IOU. The other predictors get a lamba_noobj-weighted confidence
        target of 0, and the class scores are regressed toward the one-hot
        label.
      * empty cell: every predictor gets a lamba_noobj-weighted confidence
        target of 0.

    NOTE(review): assumes predicted (x, y, w, h) are pct (image-relative)
    values, the convention IOU() documents for its first argument -- confirm
    against the YOLOv1 output head. Only the first label whose centre falls
    in a cell is used, since a YOLOv1 cell predicts a single class.
    """
    loss = res_mat.new_zeros(())
    for b, pred in enumerate(res_mat):
        rows, cols = pred.shape[1], pred.shape[2]
        targets = _assign_labels_to_cells(label_mat[b], class_mat[b], rows, cols)
        for row in range(rows):
            for col in range(cols):
                # pred is (channels, rows, cols): take every channel of one cell
                loss = loss + _cell_loss(pred[:, row, col], targets.get((row, col)))
    return loss


def _assign_labels_to_cells(boxes, classes, rows, cols):
    """Map (row, col) to the first ((xmin, ymin, xmax, ymax), class) whose centre lies in that cell."""
    targets = {}
    for box, cls in zip(boxes, classes):
        xmin, ymin, xmax, ymax = (float(v) for v in box)
        cx, cy = (xmin + xmax) / 2, (ymin + ymax) / 2
        # clamp so a centre at exactly 1.0 (or slightly outside) maps to a border cell
        col = min(max(int(cx * cols), 0), cols - 1)
        row = min(max(int(cy * rows), 0), rows - 1)
        targets.setdefault((row, col), ((xmin, ymin, xmax, ymax), int(cls)))
    return targets


def _signed_sqrt(v):
    """Square root that stays real and differentiable for zero/negative raw outputs."""
    return torch.sign(v) * torch.sqrt(torch.abs(v) + 1e-6)


def _cell_loss(cell, target):
    """Loss for one cell's B*5+C channels; target is (xyxy box, class) or None for an empty cell."""
    boxes = cell[: B * 5].reshape(B, 5)  # one row per predictor: (x, y, w, h, confidence)
    confidences = boxes[:, 4]
    if target is None:
        return lamba_noobj * (confidences ** 2).sum()

    label_box, cls = target
    xmin, ymin, xmax, ymax = label_box
    # IOU is only used to pick the responsible predictor and as a constant target
    ious = [IOU(boxes[k, :4].detach().tolist(), label_box) for k in range(B)]
    resp = max(range(B), key=ious.__getitem__)

    x_, y_, w_, h_, c_ = boxes[resp]
    tx, ty = (xmin + xmax) / 2, (ymin + ymax) / 2
    tw, th = max(xmax - xmin, 0.0), max(ymax - ymin, 0.0)
    coord = (x_ - tx) ** 2 + (y_ - ty) ** 2
    size = (_signed_sqrt(w_) - tw ** 0.5) ** 2 + (_signed_sqrt(h_) - th ** 0.5) ** 2
    obj = (c_ - ious[resp]) ** 2

    not_resp = torch.ones(B, dtype=torch.bool, device=cell.device)
    not_resp[resp] = False
    noobj = (confidences[not_resp] ** 2).sum()

    class_scores = cell[B * 5 : B * 5 + C]
    class_target = torch.zeros_like(class_scores)
    class_target[cls] = 1.0
    class_loss = ((class_scores - class_target) ** 2).sum()

    return lamba_coord * (coord + size) + obj + lamba_noobj * noobj + class_loss
    

In [None]:
from PIL import Image, ImageDraw

# Smoke test: run one batch through the model and the loss, then draw the
# batch's first image with its ground-truth boxes.
batch, labels, classes = next(iter(loader))
# report on image 0, the one drawn below (index 1 would fail for a batch of size 1)
print(batch.shape, len(labels[0]), classes[0])

res = yolo(batch.to(device=device))
loss = yolo_loss(res, labels, classes)
print("loss:", loss)

# presumably CHW float images in [0, 1]; convert the first one to an 8-bit HWC PIL image
res_img = Image.fromarray((batch[0] * 255).permute(1, 2, 0).byte().numpy())
draw = ImageDraw.Draw(res_img)
img_w, img_h = res_img.size
for box in labels[0]:
    # labels are pct (xmin, ymin, xmax, ymax); scale by the real image size, not a hard-coded 448
    xmin, ymin, xmax, ymax = (float(box[k]) for k in range(4))
    draw.rectangle((xmin * img_w, ymin * img_h, xmax * img_w, ymax * img_h), outline="red")

# last expression: Jupyter renders the PIL image inline instead of opening an external viewer
res_img
    