## Associate Classes With DocTR Detections 

In [1]:
import glob
import cv2

In [14]:
# Inputs: glob patterns for the class-wise ground-truth annotations and for
# the detection dumps, plus the folder holding the ground-truth page images.
gtimages = '../../docbank_processed/processed_data/spear_ori_black/'
gtpath = '../../docbank_processed/gt_cage_sample100/gt_cage_classwise/*.txt'
detpath = '../../docbank_processed/processed_data/out_txt_8/*.txt'
# Output: folder that receives the detections once a class label is attached.
outdetpath = '../../docbank_processed/processed_data/out_txt_8_classwise/'

gtfiles, detfiles = glob.glob(gtpath), glob.glob(detpath)
# Quick sanity check: print how many files each pattern matched.
print(len(gtfiles), len(detfiles), sep='\n')

100
100


In [15]:
def getbboxesfromgtfile(file):
    """Parse a ground-truth annotation file into labelled corner boxes.

    Every non-blank line must hold five whitespace-separated fields,
    ``label x y width height``, with integer geometry.

    Args:
        file: Path of the annotation text file.

    Returns:
        A list with one ``[label, [x0, y0, x1, y1]]`` entry per line, in
        file order.

    Raises:
        IndexError: If a non-blank line has fewer than five fields.
        ValueError: If a geometry field is not an integer.
    """
    result = []
    # The context manager releases the handle even when a line fails to parse.
    with open(file, 'r') as f:
        for line in f:
            # split() without a separator tolerates repeated blanks, tabs and
            # the trailing newline.
            fields = line.split()
            if not fields:
                continue  # skip blank lines instead of failing on them
            label = fields[0]
            x0 = int(fields[1])
            y0 = int(fields[2])
            # For cage gt the last two fields are width and height, so they
            # are added to the origin to get the opposite corner.
            x1 = x0 + int(fields[3])
            y1 = y0 + int(fields[4])
            result.append([label, [x0, y0, x1, y1]])
    return result

In [16]:
def getbboxesfromdettfile(file):
    """Parse a detection file into corner boxes that carry no class yet.

    The first line of the file is always skipped. Every following non-blank
    line must hold at least six whitespace-separated fields; field 1 is
    read as the confidence and fields 2-5 as ``x y width height``. Field 0
    is ignored.

    Args:
        file: Path of the detection text file.

    Returns:
        A list with one ``['Unassigned', [x0, y0, x1, y1], confidence]``
        entry per parsed line, in file order.

    Raises:
        IndexError: If a non-blank data line has fewer than six fields.
        ValueError: If a numeric field cannot be converted.
    """
    result = []
    # The context manager releases the handle even when a line fails to parse.
    with open(file, 'r') as f:
        # Drop the first line (presumably a header row - confirm against the
        # detection dump format). The default keeps an empty file from raising.
        next(f, None)
        for line in f:
            fields = line.split()
            if not fields:
                continue  # skip blank lines instead of failing on them
            x0 = int(fields[2])
            y0 = int(fields[3])
            # Width and height are converted to the opposite corner.
            x1 = int(fields[4]) + x0
            y1 = int(fields[5]) + y0
            conf = float(fields[1])
            result.append(['Unassigned', [x0, y0, x1, y1], conf])
    return result

In [17]:
def filterbboxesforclass(bboxdata, classname):
    """Return the entries of ``bboxdata`` whose first element equals ``classname``.

    Args:
        bboxdata: Iterable of box records; element 0 of each is its label.
        classname: Label to keep.

    Returns:
        A new list holding the matching records in their original order.
    """
    return [entry for entry in bboxdata if entry[0] == classname]

In [18]:
def iou(boxA, boxB):
    """Return the intersection-over-union of two ``(x1, y1, x2, y2)`` boxes.

    Args:
        boxA: First box.
        boxB: Second box.

    Returns:
        ``0`` when the boxes do not intersect, otherwise the intersection
        area divided by the union area.
    """
    overlapping = boxesIntersect(boxA, boxB)
    if overlapping is False:
        # Disjoint boxes: nothing to measure.
        return 0
    shared_area = getIntersectionArea(boxA, boxB)
    total_area = getUnionAreas(boxA, boxB, interArea=shared_area)
    ratio = shared_area / total_area
    assert ratio >= 0
    return ratio

# boxA = (Ax1,Ay1,Ax2,Ay2)
# boxB = (Bx1,By1,Bx2,By2)
def boxesIntersect(boxA, boxB):
    """Tell whether two corner boxes overlap; touching edges count as overlap.

    Args:
        boxA: First box.
        boxB: Second box.

    Returns:
        ``False`` when one box lies strictly beyond the other on either
        axis, ``True`` otherwise.
    """
    separated = (
        boxA[0] > boxB[2]       # boxA is right of boxB
        or boxB[0] > boxA[2]    # boxA is left of boxB
        or boxA[3] < boxB[1]    # boxA is above boxB
        or boxA[1] > boxB[3]    # boxA is below boxB
    )
    return not separated

def getArea(box):
    """Return the area of an ``(x1, y1, x2, y2)`` box, counting both end coordinates."""
    span_x = box[2] - box[0] + 1
    span_y = box[3] - box[1] + 1
    return span_x * span_y


def getIntersectionArea(boxA, boxB):
    """Return the area of the rectangle shared by two ``(x1, y1, x2, y2)`` boxes.

    No intersection test is made here; the result is only a meaningful area
    when the boxes actually overlap.
    """
    x_lo, y_lo = max(boxA[0], boxB[0]), max(boxA[1], boxB[1])
    x_hi, y_hi = min(boxA[2], boxB[2]), min(boxA[3], boxB[3])
    # Both end coordinates are counted, hence the +1 on each side.
    overlap_x = x_hi - x_lo + 1
    overlap_y = y_hi - y_lo + 1
    return overlap_x * overlap_y


def getUnionAreas(boxA, boxB, interArea=None):
    """Return the union area of two boxes as a float.

    Args:
        boxA: First box.
        boxB: Second box.
        interArea: Intersection area if already known; computed from the
            boxes when ``None``.

    Returns:
        The two box areas added together, minus the intersection area.
    """
    combined = getArea(boxA) + getArea(boxB)
    shared = getIntersectionArea(boxA, boxB) if interArea is None else interArea
    return float(combined - shared)

In [19]:
def createdetfiles(file, detections):
    """Write detections to ``file``, one ``label conf x y width height`` line each.

    Args:
        file: Path of the text file to create or overwrite.
        detections: Iterable of ``[label, [x0, y0, x1, y1], confidence]``
            records.

    Returns:
        None. The file is written as a side effect.
    """
    lines = []
    for d in detections:
        label, bbox, conf = d[0], d[1], d[2]
        # Corner coordinates are converted back to origin plus width/height.
        fields = (label, conf, bbox[0], bbox[1],
                  bbox[2] - bbox[0], bbox[3] - bbox[1])
        lines.append(' '.join(str(field) for field in fields) + '\n')
    # The context manager flushes and closes the file deterministically.
    with open(file, 'w') as f:
        f.writelines(lines)

In [20]:
# Give every detection the class of the ground-truth box it overlaps most,
# then write the labelled detections under the ground-truth file name.
#
# glob returns paths in arbitrary order, so both lists are sorted before being
# paired by position; otherwise a page's ground truth could be matched with
# another page's detections.
# NOTE(review): this assumes the ground-truth and detection file of a page
# sort to the same position in their folders - confirm the naming scheme.
gt_sorted = sorted(gtfiles)
det_sorted = sorted(detfiles)
if len(gt_sorted) != len(det_sorted):
    raise ValueError(
        'Cannot pair files by position: ' + str(len(gt_sorted))
        + ' ground-truth files vs ' + str(len(det_sorted)) + ' detection files'
    )

for gtfile, detfile in zip(gt_sorted, det_sorted):
    filename = gtfile.split('/')[-1]
    groundtruths = getbboxesfromgtfile(gtfile)
    detections = getbboxesfromdettfile(detfile)

    for d in detections:
        iou_max = 0
        for g in groundtruths:
            iou_candidate = iou(g[1], d[1])
            # Strictly greater: a detection that overlaps no ground-truth box
            # (IoU 0 everywhere) keeps its 'Unassigned' label instead of
            # inheriting the label of whichever box happened to come last.
            if iou_candidate > iou_max:
                iou_max = iou_candidate
                d[0] = g[0]
    outputfile = outdetpath + filename
    createdetfiles(outputfile, detections)