In [17]:
from PIL import Image, ImageOps
import numpy as np
import time
from copy import deepcopy
from difflib import SequenceMatcher
from paddleocr import PaddleOCR
# Shared PaddleOCR engine used by run_ocr_merger (English model, angle classification on, logging off).
ocr_model = PaddleOCR(lang='en',use_angle_cls=True,show_log=False)

In [283]:
def run_ocr_merger(image,template_data_loc,template_size):
    """OCR a document image and straighten it using its table header row.

    Work in progress: the field/table extraction stage (kept below as commented-out
    code) is disabled, so the function currently shows and returns the de-rotated image.

    Args:
        image: PIL image of the document.
        template_data_loc: dict of field name -> list of [x, y] corner points. It must
            contain a "table_data" entry with at least three corners (read for table_box).
        template_size: passed through to unrotate.

    Returns:
        The de-rotated PIL image, or the EXIF-corrected input if OCR returned no page data.
    """
    # (xmin, ymin, xmax, ymax) of the table region; only the commented-out stage uses it.
    table_box = (template_data_loc["table_data"][0][0],template_data_loc["table_data"][0][1],
             template_data_loc["table_data"][1][0],template_data_loc["table_data"][2][1])

    # Apply the EXIF orientation tag so the pixel data is upright as displayed.
    image = ImageOps.exif_transpose(image)
    img = np.asarray(image)
    raw = ocr_model.ocr(img, cls=True)
    # raw appears to hold one result per input image; an empty/None page (presumably "no
    # text detected" — confirm for the installed PaddleOCR version) would crash the steps below.
    if not raw or not raw[0]:
        return image
    # NOTE(review): this passes the per-image wrapper list, so only that one-element list is
    # "sorted" and [0] unwraps the page; the detections keep PaddleOCR's own order.
    # TODO: ordering the detections needs order_by_tbyx(raw[0]) with a sorter keyed on each
    # box's top-left corner.
    raw = order_by_tbyx(raw)[0]

    # Measure the skew from the table header row and rotate the image back to level.
    image = unrotate(raw,image,template_size)
    image.show()

#     table_data,extracted_data = extract_requested_data(template_data_loc,raw,table_box)
    
#     #Extract Table Data
#     table_output = organise_table(table_data)
#     merged_table = table_merger(table_output)
    
#     row_data = []
#     for i in merged_table[1::]:
#         temp = []
#         for j in i:
#             if j!='-':
#                 temp.append(j[1][0])
#             else:
#                 temp.append(None)
#         row_data.append(temp)
#     header_data = []
#     for i in merged_table[0]:
#         if i!='-':
#             header_data.append(i[1][0])
#         else:
#             header_data.append(None)
#     extracted_data["table_data"] = {"header":header_data,"row_data":row_data}
#     return extracted_data

    # Until the extraction stage above is re-enabled, hand back the straightened image.
    return image

In [271]:
def order_by_tbyx(ocr_info, row_tolerance=20):
    """Order OCR detections top-to-bottom, then left-to-right within a text line.

    Args:
        ocr_info: list of OCR entries [box, (text, confidence)], where box[0] is taken to be
            the top-left corner [x, y] (PaddleOCR lists corners clockwise from top-left —
            assumption, confirm).
        row_tolerance: two detections whose top-left y values differ by less than this are
            treated as the same line and ordered by x.

    Returns:
        A new list with the same entry objects, reordered.
    """
    # Primary order: top-left y, then top-left x.
    ordered = sorted(ocr_info, key=lambda entry: (entry[0][0][1], entry[0][0][0]))
    # Insertion pass: within a line, move each detection left past entries that lie to its
    # right. j must reach 0 so the first two entries are also compared.
    for i in range(len(ordered) - 1):
        for j in range(i, -1, -1):
            upper, lower = ordered[j], ordered[j + 1]
            same_line = abs(lower[0][0][1] - upper[0][0][1]) < row_tolerance
            if same_line and lower[0][0][0] < upper[0][0][0]:
                ordered[j], ordered[j + 1] = lower, upper
            else:
                break
    return ordered

In [272]:
def extract_requested_data(template_data,data,table_box):
    """Split OCR detections into table cells and template field text.

    Args:
        template_data: dict mapping field name -> list of [x, y] corner points. The key
            'table_data' is special-cased: detections overlapping table_box go to the
            table list instead of the text dict.
        data: list of OCR entries [box, (text, confidence)].
        table_box: (xmin, ymin, xmax, ymax) of the table region.

    Returns:
        (table_data, output): the entries overlapping table_box, and a dict mapping each
        field name to the space-joined text of every detection overlapping that field's
        bounding box (in the order of data). Boxes that merely touch count as overlapping.
    """
    def _bbox(points):
        # Axis-aligned (xmin, ymin, xmax, ymax) of a list of [x, y] points.
        xs = [pt[0] for pt in points]
        ys = [pt[1] for pt in points]
        return min(xs), min(ys), max(xs), max(ys)

    def _overlaps(rect, other):
        # Inclusive rectangle intersection test on (xmin, ymin, xmax, ymax) tuples.
        return (rect[0] <= other[2] and rect[2] >= other[0]
                and rect[1] <= other[3] and rect[3] >= other[1])

    # A detection's bounding box does not depend on the template field, so compute it once.
    entry_rects = [(entry, _bbox(entry[0])) for entry in data]

    table_data = []
    output = {}
    for key, corners in template_data.items():
        field_rect = _bbox(corners)
        for entry, rect in entry_rects:
            if key == 'table_data' and _overlaps(rect, table_box):
                table_data.append(entry)
            elif _overlaps(rect, field_rect):
                text = entry[1][0]
                output[key] = output[key] + " " + text if key in output else text
    return table_data, output

In [273]:
def organise_table(table_data):
    """Arrange table-region OCR detections into a row/column grid.

    table_data is a list of OCR entries [box, (text, confidence)]; presumably the table
    detections from extract_requested_data, in reading order — confirm with the caller.

    Rows: a new row starts whenever a detection's right edge lies left of the previous
    detection's left edge (the sequence wrapped back towards the left margin).
    Columns: the x-extents of the detections in the longest row (the first one on ties).
    Each grid cell (row, col) receives the first not-yet-placed detection that overlaps
    the column's x-extent and the y-extent of that row's first detection.

    Returns:
        List of rows, each a list of OCR entries or the placeholder '-' for an empty
        cell; rows with no filled cell are dropped.

    Raises:
        IndexError if table_data is empty.
    """
    # Pool of detections not yet placed in the grid; entries are removed once assigned.
    data_copy = deepcopy(table_data)
    
    # Left edge of the previously visited detection (seeded with the first detection).
    first_xmin = min([i[0] for i in table_data[0][0]])
    org_data = []
    temp = []
    for i in table_data:
        i_xmax = max([j[0] for j in i[0]])
        i_xmin = min([j[0] for j in i[0]])
        # Wrapped back to the left of the previous detection: close the current row.
        if i_xmax<first_xmin:
            org_data.append(temp)
            temp = []
        first_xmin = i_xmin
        temp.append(i)
    org_data.append(temp)
    
    # Grid size: one row per row group, one column per detection in the longest row group.
    max_rows = len(org_data)
    all_column_len = max([(n,len(i)) for n,i in enumerate(org_data)], key = lambda x: x[1])
    row_of_max_column = all_column_len[0]
    max_columns = all_column_len[1]
    output = [['-']*max_columns for _ in range(max_rows)]
    
    # Column x-extent comes from detection i of the longest row ...
    for col,i in enumerate(org_data[row_of_max_column]):
        i_xmax = max([k[0] for k in i[0]])
        i_xmin = min([k[0] for k in i[0]])
        # ... and row y-extent from the first detection of each row group.
        for row,j in enumerate([j[0] for j in org_data]):
            j_ymax = max([k[1] for k in j[0]])
            j_ymin = min([k[1] for k in j[0]])
                
            for k in data_copy:
                k_xmax = max([n[0] for n in k[0]])
                k_xmin = min([n[0] for n in k[0]])
                k_ymax = max([n[1] for n in k[0]])
                k_ymin = min([n[1] for n in k[0]])
                
                # Inclusive overlap of k's box with the column x-extent and row y-extent.
                if not (k_xmax>i_xmax and k_xmin>i_xmax) and not (k_xmax<i_xmin and k_xmin<i_xmin) and not (k_ymax>j_ymax and k_ymin>j_ymax) and not (k_ymax<j_ymin and k_ymin<j_ymin):
                    if output[row][col] == '-':
                        output[row][col] = k
                    else:
                        # NOTE(review): each (row, col) pair is visited once and the loop
                        # breaks right after the first placement, so the cell is always
                        # still '-' when reached and this merge branch looks unreachable —
                        # confirm the intended behaviour.
                        ymax = max([n[1] for n in output[row][col][0]])
                        if ymax<k_ymax:
                            new_entry = merge_words(output[row][col],k)
                        else:
                            new_entry = merge_words(k,output[row][col])
                        output[row][col] = new_entry
                    # Removing from data_copy while iterating it is only safe because of
                    # the immediate break.
                    data_copy.remove(k)
                    break
              
    # Drop rows that received no detection; iterating a reversed copy keeps remove() safe.
    for i in output[::-1]:
        if all([j == '-' for j in i]):
            output.remove(i)
    
#     if len(org_data[0]) < sum([1 if i!='-' else 0 for i in output[0]]):
#         output[0] = org_data[0]
    return output

In [274]:
def merge_words(first,second):
    """Combine two OCR detections into a single detection.

    Each argument is an OCR entry [box, (text, confidence)] whose box is a sequence of
    [x, y] points. The merged box is the axis-aligned rectangle enclosing both boxes,
    listed top-left, top-right, bottom-right, bottom-left. The merged text is first's text,
    a space, then second's text. The merged confidence is the mean of the two.
    """
    corners = list(first[0]) + list(second[0])
    left = min(pt[0] for pt in corners)
    right = max(pt[0] for pt in corners)
    top = min(pt[1] for pt in corners)
    bottom = max(pt[1] for pt in corners)

    merged_box = [[left, top], [right, top], [right, bottom], [left, bottom]]
    merged_text = ' '.join((first[1][0], second[1][0]))
    merged_conf = (first[1][1] + second[1][1]) * 0.5
    return [merged_box, (merged_text, merged_conf)]

In [275]:
def table_merger(table_data):
    """Merge sparsely filled table rows into a neighbouring row.

    table_data is a grid (list of rows) whose cells are OCR entries
    [box, (text, confidence)] or the placeholder '-' (presumably the output of
    organise_table — confirm). The input is deep-copied and never modified.

    The following steps repeat:
      1. Drop all-'-' rows.
      2. Stop when every row has at least two filled cells.
      3. Otherwise take the sparsest row (the first one on ties).
      4. Pick the pair of cells with the closest centres: one from that row, one from
         the row above it (the row below, for the first row).
      5. Merge the pair with merge_words, upper cell first, into the upper cell's
         position, and blank the lower cell.
    A single remaining row with at most one filled cell has no neighbour and is
    returned as is.

    Returns:
        The merged grid (list of rows of OCR entries / '-').
    """
    data_copy = deepcopy(table_data)

    def _centre(entry):
        # Mean of the four corner points of the entry's box.
        return (sum([p[0] for p in entry[0]]) / 4.0,
                sum([p[1] for p in entry[0]]) / 4.0)

    while True:
        data_copy = [row for row in data_copy if not all(cell == '-' for cell in row)]
        row_sum = [sum(1 for cell in row if cell != '-') for row in data_copy]
        if all(count > 1 for count in row_sum):
            break
        if len(data_copy) < 2:
            # A lone sparse row has no neighbour; pairing it with itself would never
            # reduce the cell count and would loop forever.
            break

        sparse_row = row_sum.index(min(row_sum))
        # Merge into the row above; the first row has none, so use the row below rather
        # than wrapping round to the last row through index -1.
        other_row = sparse_row - 1 if sparse_row > 0 else sparse_row + 1

        min_dist = float('inf')
        best = None  # ((row, col) of the upper cell, (row, col) of the lower cell)
        for i_col, i in enumerate(data_copy[sparse_row]):
            if i == '-':
                continue
            i_xave, i_yave = _centre(i)
            for j_col, j in enumerate(data_copy[other_row]):
                if j == '-':
                    continue
                j_xave, j_yave = _centre(j)
                dist = ((i_xave - j_xave) ** 2 + (i_yave - j_yave) ** 2) ** 0.5
                if dist < min_dist:
                    min_dist = dist
                    i_pos, j_pos = (sparse_row, i_col), (other_row, j_col)
                    best = (i_pos, j_pos) if i_yave < j_yave else (j_pos, i_pos)
        if best is None:
            break

        (top_r, top_c), (bottom_r, bottom_c) = best
        # Update by position, so equal-valued cells elsewhere in the grid are untouched.
        data_copy[top_r][top_c] = merge_words(data_copy[top_r][top_c],
                                              data_copy[bottom_r][bottom_c])
        data_copy[bottom_r][bottom_c] = '-'
    return data_copy

In [311]:
def unrotate(data,image,template_size):
    """Estimate the page skew from the table header row and rotate image to undo it.

    Finding the header row:
      - The anchor is the last OCR detection whose text matches an entry of
        initializers (case-insensitive substring, or SequenceMatcher ratio above 0.7).
      - Detections that vertically overlap the anchor and contain no digits are taken
        as header cells.

    Each of up to four passes then:
      - measures the angle from the anchor's centre to the header cell closest to
        horizontal, considering only cells within 1.5x the anchor's width;
      - moves all detections by that angle about the anchor centre with rotator;
      - re-detects the header cells with header_checker.
    The accumulated angle is finally applied to image.

    Args:
        data: list of OCR entries [box, (text, confidence)]; rotator rewrites the
            boxes in place.
        image: PIL image (anything with rotate(angle, center=...)) to straighten.
        template_size: unused here; kept for call-site compatibility.

    Returns:
        The rotated image. If no anchor keyword is found, image is returned unchanged.
        If no header cell is in range, refinement stops early and only the angle
        accumulated so far is applied.
    """
    initializers = ['Description']
    initializing_point = None
    headers = []
    for i in initializers:
        for j in data:
            if i.casefold() in j[1][0].casefold() or SequenceMatcher(None, i.casefold(), j[1][0].casefold()).ratio()>0.7:
                j_ymax = max([k[1] for k in j[0]])
                j_ymin = min([k[1] for k in j[0]])
                initializing_point = j

                for k in data:
                    k_ymax = max([n[1] for n in k[0]])
                    k_ymin = min([n[1] for n in k[0]])

                    if not (k_ymax>j_ymax and k_ymin>j_ymax) and not (k_ymax<j_ymin and k_ymin<j_ymin) and not any([n in k[1][0] for n in '1234567890']):
                        headers.append(k)
        if len(headers)>0:
            break

    if initializing_point is None:
        # No anchor keyword recognised: there is no reference line to measure skew against.
        return image

    prev_error = 0
    for _ in range(4):
        # The anchor entry is part of data, so rotator() rewrites its box too; recompute.
        anchor_box = initializing_point[0]
        initializing_xave = sum([p[0] for p in anchor_box])/4.0
        initializing_yave = sum([p[1] for p in anchor_box])/4.0
        initializer_mid_point = (initializing_xave,initializing_yave)
        # Only header cells within 1.5 anchor-widths of the anchor centre are used.
        anchor_width = max([p[0] for p in anchor_box]) - min([p[0] for p in anchor_box])

        all_angles = []
        for i in headers:
            i_xave = sum([j[0] for j in i[0]])/4.0
            i_yave = sum([j[1] for j in i[0]])/4.0
            dist = ((i_xave-initializing_xave)**2 + (i_yave-initializing_yave)**2)**0.5
            if i!=initializing_point and dist<anchor_width*1.5:
                direction = np.arctan2(i_yave-initializing_yave,i_xave-initializing_xave)
                if direction!=0:
                    sign = int(direction/abs(direction))
                else:
                    sign = 1
                # Deviation from horizontal, whichever side (0 or pi) the neighbour is on.
                dist_to_180 = np.pi - abs(direction)
                dist_to_0 = abs(direction)
                if dist_to_0<dist_to_180:
                    all_angles.append((-1*sign,dist_to_0))
                else:
                    all_angles.append((sign,dist_to_180))

        if not all_angles:
            # No neighbouring header cell in range; keep the correction accumulated so far.
            break

        # (sign, magnitude) of the smallest deviation from horizontal.
        angle = min(all_angles,key=lambda x:x[1])
        data = rotator(data,angle[0]*angle[1],initializer_mid_point)
        headers = header_checker(data,initializers)
        prev_error += angle[0]*angle[1]

    image = image.rotate(-np.degrees(prev_error),center = initializer_mid_point)
    return image

In [312]:
def rotator(data,angle,initializing_point):
    """Move every OCR box as if the page were rotated about a pivot, keeping box shapes.

    Each box's centre (mean of its four corners) is rotated by angle (radians, via
    rotate_about_point) about initializing_point. All of the box's corners are then
    translated by that same offset, so the box keeps its own orientation and size.

    Rotating the corners about the pivot and then back by -angle about the box's
    pre-rotation centre does not cancel exactly. It leaves every box shifted by
    2*(cos(angle) - 1)*(centre - pivot). Translating by the centre's displacement avoids
    that error.

    Args:
        data: list of OCR entries [box, (text, confidence)]; each entry's box is
            replaced in place.
        angle: rotation angle in radians.
        initializing_point: (x, y) pivot.

    Returns:
        data (the same list, mutated).
    """
    for entry in data:
        points = entry[0]
        centre = (sum([p[0] for p in points])/4.0, sum([p[1] for p in points])/4.0)
        new_centre = rotate_about_point(centre, origin=initializing_point, angle=angle)
        dx = new_centre[0] - centre[0]
        dy = new_centre[1] - centre[1]
        entry[0] = [[p[0] + dx, p[1] + dy] for p in points]
    return data

In [313]:
def header_checker(data,initializers):
    """Collect the OCR detections lying on the same line as a header keyword.

    Keywords in initializers are tried in order. For a keyword, every detection is an
    anchor if its text contains the keyword (case-insensitively), or if its
    SequenceMatcher ratio against the keyword exceeds 0.7.

    For every anchor, each detection is collected when both of these hold:
      - its vertical extent overlaps the anchor's;
      - its text contains none of the characters 0-9.
    Duplicates are possible when several anchors match.

    The search returns the results of the first keyword that yields anything.

    Returns:
        List of OCR entries; empty if no keyword produced a match.
    """
    digits = set('1234567890')
    for keyword in initializers:
        needle = keyword.casefold()
        found = []
        for anchor in data:
            text = anchor[1][0].casefold()
            is_match = needle in text or SequenceMatcher(None, needle, text).ratio() > 0.7
            if not is_match:
                continue
            anchor_top = min(pt[1] for pt in anchor[0])
            anchor_bottom = max(pt[1] for pt in anchor[0])
            for cand in data:
                cand_ys = [pt[1] for pt in cand[0]]
                shares_line = min(cand_ys) <= anchor_bottom and max(cand_ys) >= anchor_top
                if shares_line and digits.isdisjoint(cand[1][0]):
                    found.append(cand)
        if found:
            return found
    return []

In [314]:
def rotate_about_point(p, origin=(0, 0), angle=0):
    """Rotate one point or a list of points counter-clockwise about origin.

    Args:
        p: a single [x, y] point or a list of [x, y] points.
        origin: (x, y) pivot of the rotation.
        angle: rotation angle in radians.

    Returns:
        Plain Python lists: [x, y] for a single point, [[x, y], ...] for several.
    """
    cos_a = np.cos(angle)
    sin_a = np.sin(angle)
    rotation = np.array([[cos_a, -sin_a],
                         [sin_a, cos_a]])
    pivot_col = np.atleast_2d(origin).T      # pivot as a 2x1 column
    points_cols = np.atleast_2d(p).T         # points as 2xN columns
    rotated_cols = rotation @ (points_cols - pivot_col) + pivot_col
    return np.squeeze(rotated_cols.T).tolist()

In [315]:
# Sample document and its field template. template_size is passed to run_ocr_merger
# but is currently not used by the de-rotation step.
# Alternative sample: "../data/R-4-1.jpg"
img_path = "../data/rotational_test_sample.jpg"
template_data_loc = {"Invoice Number":[[731.0, 461.0], [927.0, 438.0], [932.0, 486.0], [736.0, 509.0]],
                     "Date":[[182.0, 666.0], [855.0, 578.0], [862.0, 634.0], [189.0, 722.0]],
                    "Total":[[962.0, 2003.0], [1118.0, 1978.0], [1126.0, 2037.0], [971.0, 2062.0]],
                    "table_data":[[43.0, 650.0], [1147.0, 650.0], [1147.0, 1407.0], [43.0, 1407.0]]}
template_size = (1215, 2689)

# Close the source file as soon as the RGB copy has been made.
with Image.open(img_path) as source_image:
    image = source_image.convert("RGB")
# perf_counter is monotonic and high-resolution, which suits measuring elapsed time.
start_time = time.perf_counter()
output = run_ocr_merger(image,template_data_loc,template_size)
print("Time Taken: %s seconds" % (time.perf_counter() - start_time))

Time Taken: 22.855565071105957 seconds


In [316]:
# Display the value returned by run_ocr_merger in the cell above.
output