In [1]:
import os
import torch

# Directory that holds the saved checkpoints. It is read from the YOLO_MODELS
# environment variable (a KeyError here means the variable is not set) and a
# leading "~" is expanded to the user's home directory.
save_dir = os.path.expanduser(os.environ["YOLO_MODELS"])

# Every "*.pth" entry found directly inside save_dir.
_checkpoint_names = [f for f in os.listdir(save_dir) if f.endswith(".pth")]
if not _checkpoint_names:
    # max() on an empty list fails with a message that never mentions the
    # directory; keep the exception type but say what is actually wrong.
    raise ValueError(f"no .pth checkpoint found in {save_dir!r}")

# The file name that sorts last.
# NOTE(review): the comparison is lexicographic, so "epoch_9.pth" sorts after
# "epoch_10.pth" - confirm the naming scheme makes "last" mean "latest".
# NOTE(review): this is a bare file name, not a path; join it with save_dir
# before opening it.
checkpoint_pth = max(_checkpoint_names)

In [2]:
from models import YOLOv1
from data.VOC_Dataset import VOC_Dataset
from data import DATA_HOME

from ipdb import set_trace
from torch.utils.data import DataLoader
from numpy import array
from multiprocessing import cpu_count
import random 
import torch
import pandas as pd

# Use the CUDA device when PyTorch reports one as available, else the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed Python's `random` module and PyTorch's default generator with the same
# fixed value. torch.manual_seed() returns the generator, which is what the
# cell displays as its output.
random.seed(1)
torch.manual_seed(1)

<torch._C.Generator at 0x177feb54e50>

In [3]:
import platform

# Mini-batch size handed to the DataLoader.
# NOTE(review): the display loop squeezes dimension 0 of the model output and
# only shows batch[0], so it appears to rely on BS == 1 - confirm before
# raising this value.
BS = 1

# Location of the VOC2007 data underneath the project's DATA_HOME directory,
# and the dataset object built from it.
_voc_root = f"{DATA_HOME}/VOCdevkit/VOC2007"
voc_ds = VOC_Dataset(_voc_root)

def collate_fn(data):
    """Merge a list of dataset samples into one mini-batch.

    Args:
        data: sequence of ``(img, label, cls)`` triples, one per sample.

    Returns:
        A 3-tuple ``(imgs, labels, classes)``:

        * ``imgs``    - float tensor made by packing all images into a single
          NumPy array and converting it.
        * ``labels``  - tuple holding each sample's label entry, unstacked.
        * ``classes`` - tuple holding each sample's class entry, unstacked.

    Notes:
        Labels are left as a plain tuple because every image has a different
        number of objects, so they have an inhomogeneous shape and cannot be
        stacked. Conceptually their dimension is
        (batch size, # of objects in each image, 4 coords).

        For reference, the network output layout is
        S * S * ((x, y, w, h, confidence) * B=2 + C=20), i.e. 7 * 7 * 30.
    """
    # Transpose the list of samples into three parallel tuples.
    images, labels, classes = zip(*data)
    image_batch = torch.tensor(array(images), dtype=torch.float)
    return image_batch, labels, classes
    
# Worker processes are only used off Windows; on Windows the loader keeps
# DataLoader's default of 0 workers, i.e. loading happens in the main process.
# NOTE(review): presumably this avoids multiprocessing problems with
# notebook-defined functions on Windows - confirm.
_num_workers = 0 if platform.system() == "Windows" else 4

# Shuffled loader over the VOC dataset that batches samples with collate_fn.
loader = DataLoader(
    voc_ds,
    batch_size=BS,
    pin_memory=True,
    shuffle=True,
    num_workers=_num_workers,
    collate_fn=collate_fn,
)

class dict:  {'bird': 0, 'sofa': 1, 'horse': 2, 'pottedplant': 3, 'bicycle': 4, 'motorbike': 5, 'boat': 6, 'aeroplane': 7, 'bottle': 8, 'bus': 9, 'person': 10, 'dog': 11, 'cow': 12, 'cat': 13, 'sheep': 14, 'car': 15, 'chair': 16, 'tvmonitor': 17, 'diningtable': 18, 'train': 19}


In [4]:
# Layout constants.
S = 7   # number of rows / columns
B = 2   # number of bounding boxes
C = 20  # number of classes

# NOTE(review): presumably the lambda_coord / lambda_noobj loss weights of
# YOLOv1 (the names are spelled "lamba" here); they are not used in the cells
# visible in this part of the notebook - confirm against the training code.
lamba_coord = 5
lamba_noobj = 0.5

# Model instance, moved to the selected device.
# NOTE(review): none of the visible cells loads `checkpoint_pth` into this
# model, so it may still carry its initial weights - confirm.
yolo = YOLOv1().to(device=device)

In [16]:
from utils.display import display_image_bbox
from utils.metrics import xywh_2_xxyy
import numpy as np
 
# Run the model over the loader and draw the predicted boxes of the first
# image of every batch.
# NOTE(review): the visible cells never switch the model to evaluation mode;
# if YOLOv1 contains dropout or batch-norm layers, call yolo.eval() before
# this loop - confirm.
for batch, labels, classes in loader:
    # Inference only: no_grad() stops autograd from recording the forward
    # pass, which saves memory and time and does not change the outputs.
    with torch.no_grad():
        # Take the prediction of the first image explicitly so it always
        # matches batch[0], which is the image displayed below.
        res = yolo(batch.to(device=device))[0]

    # One device-to-host transfer for the whole prediction rather than one
    # per box.
    pred = res.detach().cpu().numpy()

    boxes = []
    # The prediction is indexed as [channel, i, j]; per the layout noted in
    # collate_fn, channels 0-4 and 5-9 hold the two
    # (x, y, w, h, confidence) box predictions of each grid cell.
    for i in range(pred.shape[1]):
        for j in range(pred.shape[2]):
            cell = pred[:, i, j]
            boxes.append(cell[:5])
            boxes.append(cell[5:10])

    boxes = list(map(xywh_2_xxyy, boxes))

    # Clamp negative coordinates to zero, then scale by 100.
    # NOTE(review): the factor 100 is not explained in this notebook -
    # presumably it rescales normalised coordinates for display; confirm.
    boxes = [np.maximum(.0, np.array(b)) * 100 for b in boxes]
    # NOTE(review): the first two entries become min(b[0], b[2]) and
    # min(b[1], b[3]) while the last two are passed through unchanged, so a
    # box with b[0] > b[2] collapses to zero extent instead of being
    # reordered. Kept as is because the output order of xywh_2_xxyy is not
    # visible here - confirm whether max(...) was intended.
    boxes = [(min(b[0], b[2]), min(b[1], b[3]), b[2], b[3]) for b in boxes]
    display_image_bbox(batch[0], boxes)

KeyboardInterrupt: 

In [14]:
# Show the list of box tuples as last assigned by the display loop.
boxes

[(0.0, 0.0, 0.0, 0.0),
 (0.0, 0.0011984677985310555, 0.0, 0.016103041358292103),
 (0.0, 0.0013368073850870132, 0.0, 0.0013368073850870132),
 (0.0, 0.0, 0.0, 0.0005452549085021019),
 (0.0, 0.0003921561874449253, 0.0, 0.008811444509774446),
 (0.004697038501035422, 0.0, 0.005304809019435197, 0.0064894878305494785),
 (0.0, 0.0, 0.0, 0.0),
 (0.0, 0.0, 0.0, 0.0),
 (0.0, 0.0025238554226234555, 0.0, 0.0025238554226234555),
 (0.0, 0.00046626804396510124, 0.0, 0.00046626804396510124),
 (0.0, 0.0, 0.0, 0.0),
 (0.003984119510278106, 0.0, 0.003984119510278106, 0.0),
 (0.0, 0.009853772819042206, 0.0, 0.025569375604391098),
 (0.011115462519228458, 0.0, 0.014973520301282406, 0.0),
 (0.0, 0.009450451470911503, 0.0, 0.009450451470911503),
 (0.010699090780690312, 0.0, 0.010699090780690312, 0.0),
 (0.0, 0.0051350046414881945, 0.0, 0.0051350046414881945),
 (0.0, 0.0043311528861522675, 0.0, 0.020089993253350258),
 (0.0, 0.0, 0.0, 0.005611903732642531),
 (0.0, 0.009622039273381233, 0.0, 0.009622039273381233)