In [None]:
import os

import torch
from torchvision.io import ImageReadMode, read_image
from torchvision.utils import draw_bounding_boxes, save_image

In [None]:
def denorm_xywh(xywh, W, H):
    """Convert a normalized center-format box to integer pixel corners.

    Args:
        xywh: Sequence of exactly four numbers ``(cx, cy, w, h)``: box center
            and size, each expressed as a fraction of the image dimensions.
        W: Image width in pixels.
        H: Image height in pixels.

    Returns:
        ``[x1, y1, x2, y2]`` as ints; ``int`` truncates toward zero.
    """
    cx, cy, bw, bh = xywh
    half_w = bw / 2
    half_h = bh / 2
    corners = (
        (cx - half_w) * W,
        (cy - half_h) * H,
        (cx + half_w) * W,
        (cy + half_h) * H,
    )
    return [int(c) for c in corners]

In [None]:
def load_gt(path, W, H, class_names, valid_cls):
    """Read a label file and return pixel-space ground-truth boxes with names.

    Each non-blank line is parsed as ``cls v1 v2 v3 v4``. The four values are
    handed to ``denorm_xywh`` as a normalized ``(cx, cy, w, h)`` box (YOLO txt
    convention -- assumed, not validated here).

    Args:
        path: Path to the ``.txt`` label file. A missing file means "no ground
            truth" and yields two empty lists.
        W: Image width in pixels, used to de-normalize boxes.
        H: Image height in pixels, used to de-normalize boxes.
        class_names: Mapping from integer class id to display name.
        valid_cls: Collection of class ids to keep; other lines are skipped.

    Returns:
        ``(boxes, labels)``: a list of ``[x1, y1, x2, y2]`` int boxes and the
        matching list of class names, in file order.
    """
    boxes, labels = [], []
    if not os.path.exists(path):
        return boxes, labels
    # `with` guarantees the handle is closed; a bare `open()` loop leaks it.
    with open(path) as fh:
        for line in fh:
            fields = line.split()
            if not fields:
                # Blank/whitespace-only lines (e.g. a trailing empty line)
                # would otherwise fail the `cls, *xywh` unpack below.
                continue
            cls, *xywh = fields
            cls = int(cls)
            if cls not in valid_cls:
                continue
            boxes.append(denorm_xywh([float(v) for v in xywh], W, H))
            labels.append(class_names[cls])
    return boxes, labels

In [None]:
def get_preds(res, class_names, valid_cls):
    """Turn raw detection rows into integer pixel boxes and class-name labels.

    Args:
        res: Iterable of 6-element rows ``[x1, y1, x2, y2, conf, cls]`` (the
            layout unpacked below). The confidence value is ignored.
        class_names: Mapping from integer class id to display name.
        valid_cls: Collection of class ids to keep; other rows are dropped.

    Returns:
        ``(boxes, labels)``: ``[x1, y1, x2, y2]`` int boxes and the matching
        class names, preserving row order.
    """
    kept = []
    for x1, y1, x2, y2, _conf, raw_cls in res:
        cls_id = int(raw_cls)
        if cls_id in valid_cls:
            kept.append(([int(x1), int(y1), int(x2), int(y2)], class_names[cls_id]))
    boxes = [box for box, _ in kept]
    labels = [name for _, name in kept]
    return boxes, labels

In [None]:
def _draw_and_save(img_tensor, boxes, labels, color, out_path):
    """Draw `boxes` (with `labels`) in `color` on `img_tensor` and save it to `out_path`."""
    if boxes:
        drawn = draw_bounding_boxes(
            img_tensor, boxes=torch.tensor(boxes), labels=labels, colors=color, width=3
        )
    else:
        # An empty list becomes a shape-(0,) tensor, not (0, 4), which the
        # per-column box checks in draw_bounding_boxes cannot index; with
        # nothing to draw, save the image unchanged instead.
        drawn = img_tensor
    save_image(drawn.float() / 255, out_path)


def create_output(images_path, labels_path, model, img_extension=".png", imgsz=1280, output_dir=None):
    """Render ground-truth and predicted boxes for every image in a folder.

    For each file in `images_path` ending in `img_extension` (case-insensitive),
    two annotated copies with the same file name are written:

    * ``<output_dir>/gt_images/``: ground-truth boxes in blue, read from
      ``<labels_path>/<stem>.txt`` via ``load_gt``;
    * ``<output_dir>/pred_images/``: the model's predictions in red, via
      ``get_preds``.

    Args:
        images_path: Directory containing the input images.
        labels_path: Directory containing one ``.txt`` label file per image,
            named after the image's stem. Missing label files give no GT boxes.
        model: Detector called as ``model(path, imgsz=..., verbose=False)``;
            ``model.model.names`` supplies the class-id -> name dict.
        img_extension: Extension identifying image files; other entries in
            `images_path` are skipped.
        imgsz: Inference size passed to `model`.
        output_dir: Root for the two output folders; defaults to the current
            working directory.
    """
    class_names = model.model.names
    valid_cls = set(class_names.keys())

    root = os.getcwd() if output_dir is None else output_dir
    gt_images_path = os.path.join(root, "gt_images")
    pred_images_path = os.path.join(root, "pred_images")
    # exist_ok so re-running the cell does not fail with FileExistsError.
    os.makedirs(gt_images_path, exist_ok=True)
    os.makedirs(pred_images_path, exist_ok=True)

    ext = img_extension.lower()
    for filename in sorted(os.listdir(images_path)):
        if not filename.lower().endswith(ext):
            continue
        img_path = os.path.join(images_path, filename)
        # Force 3 channels: RGBA PNGs decode to 4, which draw_bounding_boxes rejects.
        img_tensor = read_image(img_path, mode=ImageReadMode.RGB)
        # read_image returns (C, H, W); de-normalize labels with the real image
        # size rather than assuming every image is 1280x720.
        img_h, img_w = img_tensor.shape[-2:]

        stem = filename[: len(filename) - len(img_extension)]
        label_path = os.path.join(labels_path, stem + ".txt")
        boxes, labels = load_gt(label_path, img_w, img_h, class_names, valid_cls)
        _draw_and_save(img_tensor, boxes, labels, "blue", os.path.join(gt_images_path, filename))

        res = model(img_path, imgsz=imgsz, verbose=False)[0].boxes.data.tolist()
        boxes, labels = get_preds(res, class_names, valid_cls)
        _draw_and_save(img_tensor, boxes, labels, "red", os.path.join(pred_images_path, filename))