In [2]:
import json
from glob import glob
from os.path import join as pjoin
from tqdm import tqdm
import cv2

In [3]:
def draw_label(img, bound, color, text=None):
    '''
    Draw a bounding box on img (in place) and, optionally, a caption above it.

    @img: image passed straight to the cv2 drawing functions
    @bound: [left, top, right, bottom]
    @color: colour of the box outline and of the caption background
    @text: caption to write above the box; nothing is written when None
    '''
    left, top = bound[0], bound[1]
    cv2.rectangle(img, (left, top), (bound[2], bound[3]), color, 3)
    if text is None:
        return

    font = cv2.FONT_HERSHEY_SIMPLEX
    # The caption is measured at scale 1.6 but rendered at 1.5, so the filled
    # strip comes out slightly larger than the rendered text.
    text_size, _baseline = cv2.getTextSize(text, font, 1.6, 3)
    text_w, text_h = text_size
    strip_top = top - 40
    # Filled strip (thickness -1) behind the caption, 40 px above the box top.
    cv2.rectangle(img, (left, strip_top), (left + text_w, strip_top + text_h), color, -1)
    cv2.putText(img, text, (left, top - 10), font, 1.5, (255, 255, 255), 3)
        

def visualize(image_name, compos, win_name='compos'):
    '''
    Draw all components of one image and show the result in an OpenCV window.

    @image_name: image file name without the '.jpg' extension, looked up under data_root
    @compos: converted format: {'bboxes':[], 'labels':[]}
    @win_name: title of the window the image is shown in
    @return: key code returned by cv2.waitKey() (waits for a key press)
    @raise FileNotFoundError: the image file is missing or cannot be decoded
    '''
    img_file = pjoin(data_root, image_name + '.jpg')
    img = cv2.imread(img_file)
    # cv2.imread reports failure by returning None rather than raising, which
    # would otherwise surface as an opaque error inside cv2.resize.
    if img is None:
        raise FileNotFoundError('Cannot read image: ' + img_file)
    # NOTE(review): 1440x2560 presumably is the resolution the bounding boxes
    # are expressed in (resize_result_bounds defaults to gt_height=2560) -- confirm.
    img = cv2.resize(img, (1440, 2560))

    # BGR colours: red for 'Text', green for every other class.
    color = {'Text': (0,0,255), 'Non-Text': (0,255,0)}
    for i, bound in enumerate(compos['bboxes']):
        label = compos['labels'][i]
        group = 'Text' if label == 'Text' else 'Non-Text'
        # The caption keeps the original class name; only the colour is binary.
        draw_label(img, bound, color[group], text=str(label))

    # Shrink for display only; drawing happened at full resolution.
    cv2.imshow(win_name, cv2.resize(img, (500, 800)))
    return cv2.waitKey()

In [4]:
def resize_result_bounds(bboxes, result_height, gt_height=2560):
    '''
    Rescale boxes from an image of height result_height to height gt_height.

    @bboxes: list of boxes, each an iterable of numeric coordinates
    @result_height: height of the image the boxes were produced on
    @gt_height: height of the target coordinate space
    @return: new list of boxes; every coordinate is multiplied by
             gt_height / result_height and truncated with int()
    '''
    factor = gt_height / result_height
    return [[int(coord * factor) for coord in box] for box in bboxes]


def cvt_data_result(data):
    '''
    Flatten a detection-result dict into parallel lists of boxes and labels.

    @ data: {'img_shape', 'compos':[{'id', 'class', 'position':{'column_min','row_min','column_max','row_max'}}]}
            (only 'compos' is read here)
    @ return compos: {'bboxes':[[left,top,right,bottom]], 'labels':[the 'class' value of each component]}
    '''
    corner_keys = ('column_min', 'row_min', 'column_max', 'row_max')
    bboxes, labels = [], []
    for compo in data['compos']:
        position = compo['position']
        # Same order as the docstring: left, top, right, bottom.
        bboxes.append([position[key] for key in corner_keys])
        labels.append(compo['class'])
    return {'bboxes': bboxes, 'labels': labels}


def cvt_data_gt(data):
    '''
    Convert one ground-truth record into the {'bboxes', 'labels'} format,
    dropping every entry whose label is listed in `skipped`.

    @ data: {'bounds':[...], 'labels':[...]} -- parallel lists read by the same index
    @ return compos: {'bboxes':[...], 'labels':[...]} with the kept entries in original order
    '''
    skipped = ('List Item', 'Toolbar', 'Modal', 'Multi-Tab', 'Drawer', 'Advertisement', 'Web View')
    bboxes, labels = [], []
    for idx, label in enumerate(data['labels']):
        # Looked up by index (not zipped) so a too-short 'bounds' list still raises.
        bound = data['bounds'][idx]
        if label not in skipped:
            bboxes.append(bound)
            labels.append(label)
    return {'bboxes': bboxes, 'labels': labels}

In [5]:
# Folder with one detection-result '<image name>.json' file per image.
result_root = r'E:\Mulong\Result\rico-layout\result\uied'
# Folder with the '<image name>.jpg' images that visualize() loads.
data_root = r'E:\Mulong\Datasets\gui\rico\combined\all'

In [6]:
# Ground-truth annotations, later indexed by image name.
# The `with` block closes the file handle as soon as parsing is done.
with open('valid_group_data.json') as gt_file:
    gt_data = json.load(gt_file)

In [7]:
# Image name to resume from; set to None to start at the first entry of select.txt.
start_point = '2060'
start = False
try:
    # `with` closes the list file even when the loop is left early with 'q'.
    with open('select.txt', 'r') as select_file:
        for img_name in select_file:
            # strip() also removes '\r' from CRLF files and stray spaces.
            img_name = img_name.strip()
            if not img_name:
                # A blank line would otherwise be looked up as '.json' / '.jpg'.
                continue
            if start_point is None or img_name == start_point:
                start = True
            if not start:
                continue
            print(img_name)

            with open(pjoin(result_root, img_name + '.json')) as result_file:
                detect_data = json.load(result_file)
            compos_detected = cvt_data_result(detect_data)
            # Bring the detected boxes to the ground-truth scale before drawing.
            compos_detected['bboxes'] = resize_result_bounds(compos_detected['bboxes'], detect_data['img_shape'][0])
            compos_gt = cvt_data_gt(gt_data[img_name])

            # Any key on the 'gt' window advances to the 'detected' window;
            # only the key pressed on the 'detected' window can quit.
            visualize(img_name, compos_gt, win_name='gt')
            key = visualize(img_name, compos_detected, win_name='detected')
            # Compare the low byte only: some backends set high modifier bits
            # in the key code, and chr() raises ValueError above 0x10FFFF.
            if key is not None and key != -1 and (key & 0xFF) == ord('q'):
                break
finally:
    # Close the OpenCV windows even if loading or drawing raised.
    cv2.destroyAllWindows()

2060
37198
12456
26903
10282
27406
31517
30609
24668
10283
