In [None]:
from models import YOLOv1
from data.VOC_Dataset import VOC_Dataset
from data import DATA_HOME

from ipdb import set_trace
from torch.utils.data import DataLoader
from numpy import array
from multiprocessing import cpu_count
import random 
import torch
import pandas as pd

# Run on the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed Python's and PyTorch's global RNGs so runs are repeatable
# (the DataLoader in a later cell is created with shuffle=True).
random.seed(1)
torch.manual_seed(1)

In [None]:
import platform

# Batch size used by the DataLoader further down in this cell.
BS = 2

# Pascal VOC 2007 split, located under the configured data root.
_voc_root = "/".join([f"{DATA_HOME}", "VOCdevkit", "VOC2007"])
voc_ds = VOC_Dataset(_voc_root)

def collate_fn(data):
    """Collate a list of (image, boxes, classes) samples into one batch.

    The images are stacked into a single float tensor. Boxes and classes are
    returned as per-image tuples: each image contains a different number of
    objects, so they cannot be stacked into one rectangular tensor.
    """
    images, boxes, box_classes = zip(*data)
    image_batch = torch.tensor(array(images), dtype=torch.float)
    return image_batch, boxes, box_classes
    
import multiprocessing

# Worker processes created with "spawn" or "forkserver" (the default on Windows,
# on macOS since Python 3.8, and on POSIX since Python 3.14) receive collate_fn
# by pickle and look it up again in __main__, which fails for functions defined
# in a notebook cell. Only use worker processes when they are forked; otherwise
# load in the main process. get_all_start_methods()[0] is the platform default
# and, unlike get_start_method(), does not fix the start method as a side effect.
_start_method = (multiprocessing.get_start_method(allow_none=True)
                 or multiprocessing.get_all_start_methods()[0])
_num_workers = 4 if _start_method == "fork" else 0

loader = DataLoader(
    voc_ds,
    batch_size=BS,
    shuffle=True,
    num_workers=_num_workers,
    # Page-locked host memory only helps host-to-GPU copies; on a CPU-only
    # machine PyTorch warns that pin_memory has no effect.
    pin_memory=torch.cuda.is_available(),
    collate_fn=collate_fn,
)

In [None]:

# Prediction-grid geometry: the network output holds S x S cells, each with
# B boxes of (x, y, w, h, confidence) followed by C class scores.
S = 7   # grid rows / columns
B = 2   # bounding boxes predicted per cell
C = 20  # number of object classes

# Loss-term weights (names kept as-is because yolo_loss reads them).
lamba_coord = 5    # scales the box-coordinate terms
lamba_noobj = 0.5  # scales the confidence term for cells without objects

yolo = YOLOv1().to(device=device)

In [None]:
from utils.metrics import IOU
from collections import defaultdict as dd

def _cell_index(coord) -> int:
    """Map a normalized box-centre coordinate to a grid index in [0, S - 1].

    The clamp matters at the far image edge: coord == 1.0 would otherwise give
    index S, which matches no cell, so the object would be silently ignored
    and its cell penalised as empty.
    """
    return min(max(int(coord * S), 0), S - 1)


def _group_boxes_by_cell(boxes, box_classes) -> dict:
    """Group one image's ground-truth boxes by the grid cell holding their centre.

    Returns a mapping (x_cell, y_cell) -> list of ((x, y, w, h), class_index).
    Keeping each class next to its box (instead of a dict keyed by coordinates)
    means two objects with identical boxes but different classes cannot collide.
    """
    cells = dd(list)
    for idx, (x, y, w, h) in enumerate(boxes):
        cells[(_cell_index(x), _cell_index(y))].append(((x, y, w, h), box_classes[idx]))
    return cells


def yolo_loss(res_mat: torch.Tensor, label_mat: list, class_mat: list, loss_df: pd.DataFrame):
    """Calculate the summed YOLO loss for a batch and log its five terms.

    @param res_mat: network output, (batch_size, B*5+C, S, S). Along dim 1 each
        cell holds B groups of (x, y, w, h, confidence) followed by C class scores.
    @param label_mat: per image, a sequence of ground-truth boxes (x, y, w, h);
        x and y are assumed normalized to [0, 1] (they are mapped to a cell by
        scaling with S).
    @param class_mat: per image, the class index of each box, in label_mat order.
    @param loss_df: running log with one column per loss term; a row with this
        batch's five term values is appended.
    @return: (total loss tensor, loss_df with the new row appended)

    NOTE(review): simplified relative to the YOLOv1 paper. Every one of the B
    predictors in a cell is regressed onto every ground-truth box in that cell
    (there is no "responsible predictor" selection), and the no-object term is
    only applied to cells that contain no box.
    """
    dev = res_mat.device
    loss1 = torch.zeros((), device=dev)  # box-centre (x, y) regression
    loss2 = torch.zeros((), device=dev)  # sqrt(w), sqrt(h) regression
    loss3 = torch.zeros((), device=dev)  # confidence vs. IOU in cells with objects
    loss4 = torch.zeros((), device=dev)  # confidence in cells without objects
    loss5 = torch.zeros((), device=dev)  # class scores in cells with objects

    for b, pred in enumerate(res_mat):
        cell_boxes = _group_boxes_by_cell(label_mat[b], class_mat[b])
        for i in range(pred.shape[1]):
            for j in range(pred.shape[2]):
                cell = pred[:, i, j]
                # NOTE(review): keys are (x-derived, y-derived) indices but are matched
                # against (dim 1, dim 2) of the prediction map; confirm the inference
                # decoder uses the same orientation.
                if (i, j) not in cell_boxes:
                    for k in range(4, B * 5, 5):
                        loss4 += cell[k] ** 2 * lamba_noobj  # yolo loss term 4
                    continue

                yprobs = torch.zeros(C, device=dev)
                for (x, y, w, h), cls in cell_boxes[(i, j)]:
                    for k in range(0, B * 5, 5):
                        x_, y_, w_, h_, c_ = cell[k:k + 5]
                        loss1 += lamba_coord * ((x - x_) ** 2 + (y - y_) ** 2)  # yolo loss term 1
                        # abs() keeps the square root real for negative raw w/h predictions.
                        loss2 += lamba_coord * ((w ** 0.5 - abs(w_) ** 0.5) ** 2
                                                + (h ** 0.5 - abs(h_) ** 0.5) ** 2)  # yolo loss term 2
                        # NOTE(review): if IOU is differentiable in the predicted box, gradients
                        # also flow through this target; the paper treats IOU as a constant.
                        loss3 += (IOU((x, y, w, h), (x_, y_, w_, h_)) - c_) ** 2  # yolo loss term 3
                    yprobs[cls] = 1.
                loss5 += ((yprobs - cell[-C:]) ** 2).sum()  # yolo loss term 5

    row = pd.DataFrame(
        [[loss1.item(), loss2.item(), loss3.item(), loss4.item(), loss5.item()]],
        columns=loss_df.columns,
    )
    # Concatenating onto an empty frame is deprecated in pandas >= 2.1 (and
    # affects the result dtype), so start the log from the first row instead.
    loss_df = row if loss_df.empty else pd.concat((loss_df, row), ignore_index=True)
    return loss1 + loss2 + loss3 + loss4 + loss5, loss_df
    

In [None]:
from sys import modules
import os
import torch.optim as optim
import matplotlib.pyplot as plt

save_dir = os.path.expanduser(os.environ["YOLO_MODELS"])
# torch.save and DataFrame.to_csv do not create missing directories.
os.makedirs(save_dir, exist_ok=True)

CHECKPOINT_EVERY = 2000  # iterations between loss-log dumps and checkpoints

optimizer = optim.Adam(yolo.parameters(), lr=1e-6)
torch.cuda.empty_cache()
loss_df = pd.DataFrame(columns=["l1", "l2", "l3", "l4", "l5"])
# Be explicit: another cell may have switched the model to eval mode.
yolo.train()
for epoch in range(1):
    for _id, (batch, labels, classes) in enumerate(loader):
        # forward-propagate
        res = yolo(batch.to(device=device))
        loss, loss_df = yolo_loss(res, labels, classes, loss_df)

        # back-propagate
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # record
        if _id % CHECKPOINT_EVERY == 0:
            # Detached CPU copy: it does not keep a reference to the autograd
            # graph, and the checkpoint loads without map_location on CPU-only machines.
            train_loss = loss.detach().cpu()
            print(f"Epoch {epoch}, iteration: {_id}, loss: {train_loss.item():.4f}")
            loss_df.to_csv(f"{save_dir}/loss_latest.csv")

            checkpoint = {
                'model': yolo.state_dict(),
                'optimizer': optimizer.state_dict(),
                'epoch': epoch,
                'iteration': _id,
                'train_loss': train_loss,
            }
            torch.save(checkpoint, f"{save_dir}/yolov1_{epoch}_{_id}.pth")

    
    