In [1]:
import json
import os
from os.path import join as pjoin
import cv2

import sys
from tqdm import tqdm
import numpy as np

In [2]:
def draw_label(img, bound, color, text=None, put_text=True):
    cv2.rectangle(img, (bound[0], bound[1]), (bound[2], bound[3]), color, 3)
    if put_text:
        # put text with rectangle 
        (w,h),_ = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 1.6, 3)
        cv2.rectangle(img, (bound[0], bound[1] - 40), (bound[0] + w, bound[1] - 40 + h), color, -1)
        cv2.putText(img, text, (bound[0], bound[1] - 10), cv2.FONT_HERSHEY_SIMPLEX, 1.5, (255,255,255), 3) 


def visualize_list(img, data, show=True):
    board = img.copy()
    is_drawn = False
    
#     if data['vertical_groups'] is not None:
#         for g in data['vertical_groups']:
#             draw_label(board, data['bounds'][g[0]], (0,0,166), 'vertical_groups')
#             for compo_id in g[1:]:
#                 b = data['bounds'][compo_id]
#                 draw_label(board, b, (0,0,166), put_text=False)
#                 is_drawn = True

#     if data['horizontal_groups'] is not None:
#         for g in data['horizontal_groups']:
#             draw_label(board, data['bounds'][g[0]], (166,0,0), 'horizontal_groups')
#             for compo_id in g[1:]:
#                 b = data['bounds'][compo_id]
#                 draw_label(board, b, (166,0,0), put_text=False)
#                 is_drawn = True
                
    if data['multitab_groups'] is not None:
        for g in data['multitab_groups']:
            draw_label(board, data['bounds'][g[0]], (166,0,166), 'multitab_groups')
            for compo_id in g[1:]:
                b = data['bounds'][compo_id]
                draw_label(board, b, (166,0,166), put_text=False)
                is_drawn = True
    
    if data['list_groups'] is not None:
        for g in data['list_groups']:
            draw_label(board, data['bounds'][g[0]], (166,100,255), 'list_groups')
            for compo_id in g[1:]:
                b = data['bounds'][compo_id]
                draw_label(board, b, (166,100,255), put_text=False)
                is_drawn = True        
                
    if data['pageindcator_groups'] is not None:
        for g in data['pageindcator_groups']:
            draw_label(board, data['bounds'][g[0]], (166,166,166), 'pageindcator_groups')
            for compo_id in g[1:]:
                b = data['bounds'][compo_id]
                draw_label(board, b, (166,166,166), put_text=False)
                is_drawn = True     
        
    if show and is_drawn:
        cv2.imshow('data', cv2.resize(board, (500, 800)))
        cv2.imshow('org', cv2.resize(img, (500, 800)))
        key = cv2.waitKey(0)
        return key
    return None


def visualize_results(result_root, image_name):
    uied_result = cv2.imread(pjoin(result_root, 'uied', image_name + '.jpg'))
    layout_result = cv2.imread(pjoin(result_root, 'layout', image_name + '-list.jpg'))
    cv2.imshow('uied', cv2.resize(uied_result, (500, 800)))
    cv2.imshow('layout', cv2.resize(layout_result, (500, 800)))
    

def visualize_ground_truth(image_name, gt_data):
    img_file = pjoin(data_root, image_name + '.jpg')
    img_data = gt_data[image_name]
    img = cv2.imread(img_file)
    
    org_json_file = pjoin(data_root, image_name + '.json')
    org_data = json.load(open(org_json_file))
    org_shape = org_data['activity']['root']['bounds']
    
    img = cv2.resize(img, (1440, 2560))
    return visualize_list(img, img_data)

In [3]:
def print_groups(gt, result):
    print('*** Ground Truth Number of groups:%d ***' % len(gt))
    for group in gt:
        print('Number of items:', len(group))
        for item in group:
            print(item)
        
    print('*** Result Number of groups:%d***' % len(result))
    for group in result:
        print('Number of items:', len(group))
        for item in group:
            print(item)


def cvt_gt_data(gt_data):
    gt_groups = []
    lst = sorted(gt_data['list'].items(), key=lambda x: x[0])
    lst_new = []
    for l in lst:
        if len(l[1]) > 0:
             lst_new.append(l[1])
    if len(lst_new) > 0: gt_groups.append(lst_new)
        
    multitab = sorted(gt_data['multitab'].items(), key=lambda x: x[0])
    mb_new = []
    for m in multitab:
#         if len(m[1]) > 0:
        mb_new.append(m[1])
    if len(mb_new) > 0: gt_groups.append(mb_new)
    return gt_groups


def calc_matched_items_number(g1, g2):
    matched_item_num = 0
    # element number each item
    matched_item_id = 0
    for i in range(len(g1)):
        for j in range(matched_item_id, len(g2)):
            # compare the element number in the two list items
            if abs(len(g1[i]) - len(g2[j])) <= 2:
                matched_item_num += 1
                matched_item_id = j + 1
    return matched_item_num


def evaluate_gui(gt_groups, result_groups, thresh=1):
    correct_list_num = 0
    correct_item_num = 0
    
    tp, fp, fn = 0, 0, len(gt_groups)
    marked = np.full(len(gt_groups), False)
    for res_group in result_groups:
        matched = False
        for j, gt_group in enumerate(gt_groups):
            if marked[j]: continue
            matched_items_num = calc_matched_items_number(gt_group, res_group)
            if len(gt_group) - matched_items_num <= thresh:
                tp += 1
                fn -= 1
                matched = True
                marked[j] = True
                break
        if not matched:
            fp += 1
    return tp, fp, fn

In [32]:
data_root = 'E:\\Mulong\\Datasets\\gui\\rico\\combined\\all'
result_root_det = 'E:\\Mulong\\Result\\rico-layout\\result'
result_root_gt_compo = 'E:\\Mulong\\Result\\rico-layout\\result-gt-compos-resize'
group_data_gt = json.load(open('item_data_class.json'))
gui_gt = json.load(open('valid_group_data.json'))

In [5]:
def inspect_results(result_root, start_point = None, thresh=1):
    start = False
    for img_name in open('select.txt', 'r').readlines():
        img_name = img_name.replace('\n', '')
        if start_point is None or img_name == start_point: start = True
        if not start: continue
        print('\n', img_name)

        group_gt = cvt_gt_data(group_data_gt[img_name])
        group_result = json.load(open(pjoin(result_root, 'layout', img_name + '-list.json')))['list']
        print('TP:%d FP:%d FN:%d' % evaluate_gui(group_gt, group_result, thresh=thresh))

        print_groups(group_gt, group_result)
        visualize_results(result_root, img_name)
        key = visualize_ground_truth(img_name, gui_gt)
        if key is not None and key != -1 and chr(key) == 'q':
            break
    cv2.destroyAllWindows()

In [None]:
inspect_results(result_root_gt_compo, None)


 2114
TP:0 FP:1 FN:1
*** Ground Truth Number of groups:1 ***
Number of items: 10
['TextView', 'TextView']
['TextView', 'TextView']
['TextView', 'TextView', 'TextView']
['TextView', 'TextView']
['TextView']
['TextView', 'TextView']
['TextView']
['TextView', 'TextView']
['TextView', 'TextView']
['TextView', 'TextView']
*** Result Number of groups:1***
Number of items: 8
['Compo', 'Text']
['Compo', 'Text']
['Compo', 'Text']
['Compo', 'Text']
['Compo', 'Text']
['Text', 'Compo']
['Text', 'Compo']
['Text', 'Compo']

 30783
TP:1 FP:0 FN:0
*** Ground Truth Number of groups:1 ***
Number of items: 8
['TextView']
['TextView']
['TextView']
['TextView']
['TextView']
['TextView']
['TextView']
['TextView', 'TextView']
*** Result Number of groups:1***
Number of items: 8
['Text']
['Text']
['Text']
['Text']
['Text']
['Text']
['Text']
['Text']

 41634
TP:0 FP:1 FN:1
*** Ground Truth Number of groups:1 ***
Number of items: 8
['View', 'TextView']
['View', 'TextView']
['View', 'TextView']
['View', 'TextVie

In [None]:
1176

In [8]:
TP, FP, FN = 102, 890, 1120
pre = TP / (TP+ FP)
recall = TP / (TP + FN)
f1 = (2 * pre * recall) / (pre + recall)

In [None]:
print(pre, recall, f1)

In [10]:
TP, FP, FN = 136, 902, 1040
pre = TP / (TP+ FP)
recall = TP / (TP + FN)
f1 = (2 * pre * recall) / (pre + recall)

In [11]:
print(pre, recall, f1)

0.13102119460500963 0.11564625850340136 0.12285456187895212


In [7]:
def evaluate_all(result_root, thresh=1):
    TP, FP, FN = 0, 0, 0
    for img_name in tqdm(open('select.txt', 'r').readlines()):
        img_name = img_name.replace('\n', '')

        group_gt = cvt_gt_data(group_data_gt[img_name])
        group_result = json.load(open(pjoin(result_root, 'layout', img_name + '-list.json')))['list']

        tp, fp, fn = evaluate_gui(group_gt, group_result, thresh=thresh)
        TP, FP, FN = TP + tp, FP + fp, FN + fn 
    #     print(img_name, tp, fp, fn)

    pre = TP / (TP+ FP)
    recall = TP / (TP + FN)
    f1 = (2 * pre * recall) / (pre + recall)
    print('TP:%d, FP:%d, FN:%d, Total:%d' % (TP, FP, FN, (TP + FN)))
    print('Precision:%.3f, Recall:%.3f, F1:%.3f' %(pre, recall, f1))

In [30]:
t = 4

In [31]:
evaluate_all(result_root_gt_compo, t)

100%|████████████████████████████████████████████████████████████████████████████| 1091/1091 [00:00<00:00, 3064.14it/s]

TP:1030, FP:477, FN:149, Total:1179
Precision:0.683, Recall:0.874, F1:0.767





In [29]:
evaluate_all(result_root_det, t)

100%|████████████████████████████████████████████████████████████████████████████| 1091/1091 [00:00<00:00, 2933.16it/s]

TP:613, FP:790, FN:566, Total:1179
Precision:0.437, Recall:0.520, F1:0.475



