In [None]:
import os
import random
import shutil
import time

import cv2
import numpy as np
import torch

from datasets.detection_dataset import load_images
from models.yolov3quad import Darknet
from utils.metrics import non_max_suppression

In [None]:
# --- Paths -------------------------------------------------------------------
INP_DIR = '../data/val_images/'    # directory handed to load_images
OUT_DIR = 'output_folder'          # wiped and recreated on every run
MODEL_CFG = './cfgs/yolov3.cfg'    # network definition passed to Darknet
WEIGHTS_PATH = './weights/best.pt' # checkpoint read with torch.load

# --- Runtime -----------------------------------------------------------------
# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda:0')
else:
    DEVICE = torch.device('cpu')
IMG_SIZE = 19 * 32                 # 608
BATCH_SIZE = 1

# --- Detection thresholds ----------------------------------------------------
CONF_THRES = 0.1                   # candidates must score above this in column 8
NMS_THRES = 0.2                    # third argument of non_max_suppression

# --- Output switches ---------------------------------------------------------
TEXT_OUT = True                    # write one .txt result file per image
PLOT_FLAG = True                   # draw detections onto the image

In [None]:
# Start every run from an empty output directory. shutil.rmtree deletes the
# tree without spawning a shell, so OUT_DIR is never parsed as shell syntax and
# the cell does not depend on an `rm` binary being available.
# ignore_errors=True keeps a missing directory from raising.
shutil.rmtree(OUT_DIR, ignore_errors=True)
os.makedirs(OUT_DIR, exist_ok=True)

# Build the network at the configured input size so that the model and the
# image loader below always agree on IMG_SIZE.
model = Darknet(MODEL_CFG, img_size=IMG_SIZE)

# NOTE: torch.load unpickles the file; only load checkpoints from a trusted
# source. The weights are mapped to CPU first and moved to DEVICE afterwards.
checkpoint = torch.load(WEIGHTS_PATH, map_location='cpu')
model.load_state_dict(checkpoint['model'])
del checkpoint  # the state dict has been copied into the model

model.to(DEVICE).eval()
classes = ['screen']
dataloader = load_images(INP_DIR, BATCH_SIZE, IMG_SIZE)

In [None]:
# Run the detector over every batch and collect, image by image, the path and
# the detections kept by confidence filtering and non-maximum suppression.
# img_paths[k] and img_detections[k] always describe the same image.
img_paths = []
img_detections = []
t0 = time.time()
for batch_idx, (imgs, paths) in enumerate(dataloader):
    with torch.no_grad():
        preds = model(imgs.to(DEVICE))
        # Filter each image on its own. A single boolean mask over the whole
        # batch would merge the candidates of all images into one set and leave
        # the two result lists misaligned whenever BATCH_SIZE > 1.
        for img_preds, path in zip(preds, paths):
            img_preds = img_preds[img_preds[:, 8] > CONF_THRES]
            if len(img_preds) == 0:
                # Nothing above the threshold: the image is left out entirely.
                continue
            # NOTE(review): the second argument is presumably the confidence
            # threshold (it was a literal 0.1, equal to CONF_THRES) -- confirm
            # against utils.metrics.non_max_suppression.
            # Assumes one result entry is returned per batch item -- TODO confirm.
            detections = non_max_suppression(img_preds.unsqueeze(0), CONF_THRES, NMS_THRES)
            img_detections.extend(detections)
            img_paths.append(path)
    print('Batch %d... (Done %.3f s)' % (batch_idx, time.time() - t0))
    t0 = time.time()

In [None]:
def plot_one_box(x, img, color=None, line_thickness=None):
    """Draw the closed outline of a four-corner polygon onto ``img``.

    Args:
        x: Indexable of eight numbers ``x1, y1, x2, y2, x3, y3, x4, y4``;
            each value is truncated with ``int`` before drawing.
        img: Image handed to ``cv2.line``; ``img.shape[0:2]`` is read to
            derive the default thickness.
        color: Color handed to ``cv2.line``. When falsy, three random
            integers in ``[0, 255]`` are used.
        line_thickness: Line thickness. When falsy, it defaults to
            ``round(0.001 * max(img.shape[0:2])) + 1``.
    """
    if line_thickness:
        thickness = line_thickness
    else:
        thickness = round(0.001 * max(img.shape[0:2])) + 1

    if not color:
        color = [random.randint(0, 255) for _ in range(3)]

    # Each pair gives the offsets of the x-coordinates of an edge's two end
    # points; the last pair closes the outline back to the first corner.
    for start, end in ((0, 2), (2, 4), (4, 6), (6, 0)):
        cv2.line(img,
                 (int(x[start]), int(x[start + 1])),
                 (int(x[end]), int(x[end + 1])),
                 color, thickness)

In [None]:
def _to_image_coord(value, pad, unpad, full):
    """Map one network-input coordinate back onto an axis of the original image.

    Subtracts half of the padding (floored), rescales from the unpadded network
    extent ``unpad`` to the original extent ``full``, rounds, and clamps
    negative results to 0.
    """
    return max(((value - pad // 2) / unpad * full).round().item(), 0)


# One random color per entry in `classes`.
color_list = [[random.randint(0, 255) for _ in range(3)] for _ in classes]

for path, detections in zip(img_paths, img_detections):
    img = cv2.imread(path)
    if img is None:
        # cv2.imread returns None instead of raising when a file cannot be
        # read; skip the image rather than crash on img.shape below.
        print('Could not read image: %s' % path)
        continue
    img_h, img_w = img.shape[:2]

    # The amount of padding that was added to make the image square, expressed
    # in network-input pixels.
    scale = IMG_SIZE / max(img_h, img_w)
    pad_x = max(img_h - img_w, 0) * scale
    pad_y = max(img_w - img_h, 0) * scale
    # Image height and width after the padding is removed.
    unpad_h = IMG_SIZE - pad_y
    unpad_w = IMG_SIZE - pad_x

    if detections is None:
        continue

    class_ids = detections[:, -1].cpu()
    unique_classes = class_ids.unique()
    bbox_colors = random.sample(color_list, len(unique_classes))
    # Direct class-id -> color lookup for the drawing step below.
    class_colors = {int(c): color for c, color in zip(unique_classes, bbox_colors)}

    # Results go to OUT_DIR/<image file name>.txt; a stale file from an earlier
    # run is removed even when TEXT_OUT is disabled.
    file_name = os.path.basename(path)
    results_txt_path = os.path.join(OUT_DIR, file_name) + '.txt'
    if os.path.isfile(results_txt_path):
        os.remove(results_txt_path)

    # Per-class detection count for this image.
    for class_id in unique_classes:
        n = (class_ids == class_id).sum()
        print('%g %ss' % (n, classes[int(class_id)]))

    txt_lines = []
    # Each row holds eight corner coordinates followed by three trailing values
    # that are unpacked as conf, cls_conf and cls_pred.
    for *quad, conf, cls_conf, cls_pred in detections:
        xs = [_to_image_coord(v, pad_x, unpad_w, img_w) for v in quad[0::2]]
        ys = [_to_image_coord(v, pad_y, unpad_h, img_h) for v in quad[1::2]]
        # Back to the flat x1, y1, ..., x4, y4 layout.
        points = [c for xy in zip(xs, ys) for c in xy]

        if TEXT_OUT:
            txt_lines.append(
                '%g %g %g %g %g %g %g %g %g %g \n' % (*points, cls_pred, cls_conf * conf))

        if PLOT_FLAG:
            plot_one_box(points, img, color=class_colors[int(cls_pred)])

    # Write the file once instead of reopening it for every detection; nothing
    # is created when there are no rows to write.
    if TEXT_OUT and txt_lines:
        with open(results_txt_path, 'w') as file:
            file.writelines(txt_lines)

    # Only open the (blocking) preview window when plotting is enabled.
    # NOTE(review): the annotated image is displayed but never saved to OUT_DIR.
    if PLOT_FLAG:
        cv2.imshow(file_name, img)
        cv2.waitKey(0)
        cv2.destroyAllWindows()