In [1]:
import json
from glob import glob
from os.path import join as pjoin
from tqdm import tqdm
import cv2
import numpy as np

In [2]:
def draw_label(img, bound, color, text=None):
    cv2.rectangle(img, (bound[0], bound[1]), (bound[2], bound[3]), color, 3)
    if text is not None:
        # put text with rectangle 
        (w,h),_ = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 1.6, 3)
        cv2.rectangle(img, (bound[0], bound[1] - 40), (bound[0] + w, bound[1] - 40 + h), color, -1)
        cv2.putText(img, text, (bound[0], bound[1] - 10), cv2.FONT_HERSHEY_SIMPLEX, 1.5, (255,255,255), 3)
        

def visualize(image_name, compos, win_name='compos'):
    '''
    @compos: converted format: {'bboxes':[], 'labels':[]}
    '''
    img_file = pjoin(data_root, image_name + '.jpg')
    img = cv2.imread(img_file)
    img = cv2.resize(img, (1440, 2560))

    color = {'Text': (0,0,255), 'Non-Text': (0,255,0)}
    for i in range(len(compos['bboxes'])):
        bound = compos['bboxes'][i]
        label = compos['labels'][i]
        # if label in ('List Item', 'Toolbar', 'Modal', 'Multi-Tab', 'Drawer', 'Advertisement', 'Web View'):
        #     continue
        label = 'Text' if label == 'Text' else 'Non-Text'
        draw_label(img, bound, color[label], text=str(compos['labels'][i]))

    cv2.imshow(win_name, cv2.resize(img, (500, 800)))
    return cv2.waitKey()

In [3]:
def resize_result_bounds(bboxes, result_height, gt_height=2560):
    bboxes_new = []
    scale = gt_height / result_height
    for bbox in bboxes:
        bbox = [int(b * scale) for b in bbox]
        bboxes_new.append(bbox)
    return bboxes_new


def cvt_data_result(data):
    '''
    @ data: {'img_shape', 'compos':{'id', 'class', 'position':{'column_min','row_min','column_max','row_max'}}}
    @ return compos: {'bbox':[[left,top,right,bottom]], 'labels':[Text(1)/Non-Text(2)]}
    '''
    compos = {'bboxes': [], 'labels': []}
    for compo in data['compos']:
        position = compo['position']
        bound = [position['column_min'], position['row_min'], position['column_max'], position['row_max']]
        label = compo['class']
#         label = 1 if compo['class'] == 'Text' else 2
        compos['bboxes'].append(bound)
        compos['labels'].append(label)        
    return compos


def cvt_data_gt(data):
    compos = {'bboxes': [], 'labels': []}
    for i in range(len(data['labels'])):
        bound = data['bounds'][i]
        label = data['labels'][i]
        if label in ('List Item', 'Toolbar', 'Modal', 'Multi-Tab', 'Drawer', 'Advertisement', 'Web View'):
            continue
#         label = 1 if label == 'Text' else 2
        compos['bboxes'].append(bound)
        compos['labels'].append(label)
    return compos

In [4]:
def calc_iou(b1, b2):
    '''
    bbox:[left,top,right,bottom]
    '''
    left = max(b1[0], b2[0])
    top = max(b1[1], b2[1])
    right = min(b1[2], b2[2])
    bottom = min(b1[3], b2[3])
    width = max(0, right - left)
    height = max(0, bottom - top)
   
    area1 = (b1[2] - b1[0]) * (b1[3] - b1[1])
    area2 = (b2[2] - b2[0]) * (b2[3] - b2[1])
    area_inter = width * height
    
    iou = area_inter / (area1 + area2)
    ioa = area_inter / area1
    iob = area_inter / area2
    return iou, ioa, iob


def evaluate(compos_detected, compos_gt, iou_thresh=0.9):
    TP, FP, FN = 0, 0, len(compos_gt['bboxes'])
    marked = np.full(len(compos_gt['bboxes']), False)
    for i, det in enumerate(compos_detected['bboxes']):
        matched = False
        for j, gt in enumerate(compos_gt['bboxes']):
            if marked[j]: continue
            iou, ioa, iob = calc_iou(det, gt)
            if iou > iou_thresh or ioa > iou_thresh: 
                TP += 1
                FN -= 1
                matched = True
                marked[j] = True
                break
        if not matched:
            FP += 1
    return TP, FP, FN


def evaluate_text(compos_detected, compos_gt, iou_thresh=0.9):
    bboxes = []
    for i, label in enumerate(compos_gt['labels']):
        if 'Text' in label:
            bboxes.append(compos_gt['bboxes'][i])
    
    TP, FP, FN = 0, 0, len(bboxes)
    marked = np.full(len(bboxes), False)
    for i, det in enumerate(compos_detected['bboxes']):
        if compos_detected['labels'][i] != 'Text': continue
        matched = False
        for j, gt in enumerate(bboxes):
            iou, ioa, iob = calc_iou(det, gt)
            if iou > iou_thresh or ioa > iou_thresh: 
                TP += 1
                FN -= 1
                matched = True
                marked[j] = True
                break
        if not matched:
            FP += 1
    return TP, FP, FN


def evaluate_nontext(compos_detected, compos_gt, iou_thresh=0.9):
    bboxes = []
    for i, label in enumerate(compos_gt['labels']):
        if 'Text' not in label:
            bboxes.append(compos_gt['bboxes'][i])
    
    TP, FP, FN = 0, 0, len(bboxes)
    marked = np.full(len(bboxes), False)
    for i, det in enumerate(compos_detected['bboxes']):
        if compos_detected['labels'][i] == 'Text': continue
        matched = False
        for j, gt in enumerate(bboxes):
            iou, ioa, iob = calc_iou(det, gt)
            if iou > iou_thresh or ioa > iou_thresh: 
                TP += 1
                FN -= 1
                matched = True
                marked[j] = True
                break
        if not matched:
            FP += 1
    return TP, FP, FN


def evaluate_class(compos_detected, compos_gt, iou_thresh=0.9):
    det_labels = [1 if label == 'Text' else 2 for label in compos_detected['labels']]
    gt_labels = [1 if 'Text' in label else 2 for label in compos_gt['labels']]
    
    TP, FP, FN = 0, 0, len(compos_gt['bboxes'])
    marked = np.full(len(compos_gt['bboxes']), False)
    for i, det in enumerate(compos_detected['bboxes']):
        matched = False
        for j, gt in enumerate(compos_gt['bboxes']):
            if marked[j]: continue
            if det_labels[i] != gt_labels[j]: continue
            iou, ioa, iob = calc_iou(det, gt)
            if iou > iou_thresh or ioa > iou_thresh: 
                TP += 1
                FN -= 1
                matched = True
                marked[j] = True
                break
        if not matched:
            FP += 1
    return TP, FP, FN

In [5]:
result_root = 'E:\\Mulong\\Result\\rico-layout\\result\\uied'
data_root = 'E:\\Mulong\\Datasets\\gui\\rico\\combined\\all'
gt_data = json.load(open('valid_group_data.json'))

In [10]:
TP, FP, FN = 0, 0, 0

start_point = None
start = False
for img_name in tqdm(open('select.txt', 'r').readlines()):
    img_name = img_name.replace('\n', '')
    if start_point is None or img_name == start_point: start = True
    if not start: continue
    
    detect_data = json.load(open(pjoin(result_root, img_name + '.json')))
    compos_detected = cvt_data_result(detect_data)
    compos_detected['bboxes'] = resize_result_bounds(compos_detected['bboxes'], detect_data['img_shape'][0]) 
    compos_gt = cvt_data_gt(gt_data[img_name])
    
#     tp, fp, fn = evaluate(compos_detected, compos_gt, iou_thresh=0.9)
    tp, fp, fn = evaluate_class(compos_detected, compos_gt, iou_thresh=0.9)
#     tp, fp, fn = evaluate_text(compos_detected, compos_gt, iou_thresh=0.9)
#     tp, fp, fn = evaluate_nontext(compos_detected, compos_gt, iou_thresh=0.9)
    TP, FP, FN = TP + tp, FP + fp, FN + fn 
    
#     print(img_name)
#     print(tp, fp, fn)
#     visualize(img_name, compos_gt, win_name='gt')
#     key = visualize(img_name, compos_detected, win_name='detected')
#     if key is not None and key != -1:
#         if chr(key) == 'q':
#             break
# cv2.destroyAllWindows()

pre = TP / (TP+ FP)
recall = TP / (TP + FN)
f1 = (2 * pre * recall) / (pre + recall)
print('TP:%d, FP:%d, FN:%d, Total:%d' % (TP, FP, FN, (TP + FN)))
print('Precision:%.3f, Recall:%.3f, F1:%.3f' %(pre, recall, f1))

100%|████████████████████████████████████████████████████████████████████████████| 1091/1091 [00:00<00:00, 1208.51it/s]

TP:15953, FP:10175, FN:6340, Total:22293
Precision:0.611, Recall:0.716, F1:0.659





In [None]:
# text
TP:9275, FP:4397, FN:4112, Total:13387
Precision:0.678, Recall:0.693, F1:0.686

In [None]:
# nontext
TP:7334, FP:5122, FN:1572, Total:8906
Precision:0.589, Recall:0.823, F1:0.687

In [None]:
# all
TP:15164, FP:10964, FN:7129, Total:22293
Precision:0.580, Recall:0.680, F1:0.626

In [42]:
img = cv2.imread(pjoin(data_root, img_name + '.jpg'))
img = cv2.resize(img, (1440, 2560))

TP, FP, FN = 0, 0, len(compos_gt['bboxes'])
marked = np.full(len(compos_gt['bboxes']), False)
for i, det in enumerate(compos_detected['bboxes']):
    board = img.copy()
    draw_label(board, det, (0, 255, 0))
    
    matched = False
    for j, gt in enumerate(compos_gt['bboxes']):
        if marked[j]: continue
        iou, ioa, iob = calc_iou(det, gt)
        if iou > 0.1 or ioa == 1: 
            TP += 1
            FN -= 1
            matched = True
            marked[j] = True
            
            print(iou, ioa, iob)    
            draw_label(board, gt, (0, 0, 255))        
            cv2.imshow('iou', cv2.resize(board, (500, 800)))
            cv2.waitKey()
            break
    if not matched:
        FP += 1
cv2.destroyAllWindows()

0.07142857142857142 1.0 0.07692307692307693
0.49193840001701655 0.9741791453423475 0.9937694704049844
0.48223401566843566 0.9902439024390244 0.94
0.34578374860542954 0.7856049004594181 0.6176360692489725
0.1073578184498219 1.0 0.12026971239851383
0.09458562488147165 1.0 0.10446667015761638
0.1401076616768675 1.0 0.16293628333762114
0.13626100502336652 1.0 0.15775715327875497
0.11824239376414383 1.0 0.13409852427461325
0.39146018388655135 1.0 0.6432778489116517
0.22750917692322792 0.488 0.42621262824991724
0.41056603773584904 1.0 0.6965428937259923
0.21760456612860074 0.4941579685309238 0.38882555559641574
0.38065027755749403 1.0 0.614596670934699
0.24257954302797352 0.5138157894736842 0.4595300261096606
0.415349832255555 0.9970238095238095 0.711934327860655
