In [1]:
import functools

from paddleocr import PaddleOCR
from PIL import Image
import torch
from transformers import DetrFeatureExtractor
from transformers import TableTransformerForObjectDetection
import numpy as np
# Shared PaddleOCR engine (English model, angle classifier disabled, logging silenced);
# reused for whole-page OCR and for the table crops in table_structure_detection.
ocr_model = PaddleOCR(lang='en',use_angle_cls=False,show_log=False)

In [2]:
def get_midpoints(data):
    """Map the centre of each OCR detection to its recognised text.

    Args:
        data: one page of PaddleOCR output, i.e. a list of entries whose first
            element is a list of ``[x, y]`` corner points and whose second
            element starts with the recognised text (``entry[1][0]``).
            ``None`` or ``[]`` are treated as "no detections".

    Returns:
        dict mapping ``(x_centre, y_centre)`` (the mean of the corner points) to
        the recognised text, in input order. Detections with exactly the same
        centre overwrite each other (last one wins); entries with no corner
        points are skipped.
    """
    midpoints = {}
    # PaddleOCR can return None for a page with no detected text.
    for detection in data or []:
        points, text = detection[0], detection[1][0]
        if not points:
            continue
        count = len(points)  # 4 for PaddleOCR quadrilaterals
        centre = (sum(p[0] for p in points) / count,
                  sum(p[1] for p in points) / count)
        midpoints[centre] = text
    return midpoints

In [3]:
def extract_requested_data(template_data_loc, data):
    """Collect the OCR text that falls inside each named template region.

    Args:
        template_data_loc: dict mapping a field name to the corner points
            ``[[x, y], ...]`` of its region in image pixel coordinates. Only the
            min/max extent of the points is used, so corner order does not
            matter (and two opposite corners are enough).
        data: one page of PaddleOCR output, converted with ``get_midpoints``.

    Returns:
        dict mapping each field name that matched at least one detection to its
        text. A detection belongs to a region when its centre lies strictly
        inside the region's bounds. Several detections in one region are joined
        with a single space, in the order ``get_midpoints`` yields them. Fields
        with no match are absent from the result.
    """
    # Axis-aligned bounds per region, computed once instead of once per detection.
    bounds = {}
    for field, corners in template_data_loc.items():
        xs = [corner[0] for corner in corners]
        ys = [corner[1] for corner in corners]
        bounds[field] = (min(xs), max(xs), min(ys), max(ys))

    output = {}
    for (x, y), text in get_midpoints(data).items():
        for field, (x_min, x_max, y_min, y_max) in bounds.items():
            if x_min < x < x_max and y_min < y < y_max:
                if field in output:
                    output[field] = output[field] + " " + text
                else:
                    output[field] = text
    return output

In [4]:
def scaler(box, xscale, yscale, xmax, ymax):
    """Grow a ``[x0, y0, x1, y1]`` box by scaling its coordinates, clamped to the image.

    The start coordinates are multiplied by ``1 - scale`` and the end
    coordinates by ``1 + scale``. This scales *coordinates* about the image
    origin, so the padding added grows with the box's distance from the
    top-left corner rather than with the box's own size. For example,
    ``xscale=1`` gives ``x0 = 0`` and ``x1 = min(2 * x1, xmax)``.

    Args:
        box: ``[x0, y0, x1, y1]`` (left, top, right, bottom). Not modified.
        xscale: fractional scale factor for the x coordinates.
        yscale: fractional scale factor for the y coordinates.
        xmax: upper bound for ``x1`` (e.g. image width).
        ymax: upper bound for ``y1`` (e.g. image height).

    Returns:
        A new list with the scaled box clamped to ``[0, xmax]`` x ``[0, ymax]``.
        Any elements after the first four are copied through unchanged.
    """
    scaled = list(box)  # copy so the caller's box is left untouched
    # Clamp the start edges at 0: a scale > 1 would otherwise make them negative.
    scaled[0] = max(scaled[0] * (1 - xscale), 0)
    scaled[1] = max(scaled[1] * (1 - yscale), 0)
    scaled[2] = min(scaled[2] * (1 + xscale), xmax)
    scaled[3] = min(scaled[3] * (1 + yscale), ymax)
    return scaled

In [5]:
# Class names in the structure-recognition model's config.id2label that this notebook reads.
_COLUMN_HEADER_LABEL = "table column header"
_ROW_LABEL = "table row"
_STRUCTURE_MODEL_NAME = "microsoft/table-transformer-structure-recognition"


@functools.lru_cache(maxsize=None)
def _load_structure_model(model_name):
    """Load the DETR feature extractor and Table Transformer once; later calls reuse them."""
    return DetrFeatureExtractor(), TableTransformerForObjectDetection.from_pretrained(model_name)


def _ocr_texts(crop):
    """Return the text strings PaddleOCR recognises in a PIL image crop (possibly empty)."""
    result = ocr_model.ocr(np.asarray(crop))
    # PaddleOCR can return [None] for a crop with no detected text; treat that as no lines.
    page = result[0] if result else None
    return [line[1][0] for line in page or []]


def table_structure_detection(image, threshold=0.8, y_scale=0.03):
    """Detect the structure of a table image and OCR its column header and rows.

    Args:
        image: PIL image containing only the table.
        threshold: minimum detection score kept by post-processing.
        y_scale: vertical scale factor passed to ``scaler`` for each crop.

    Returns:
        ``{"headers": [...], "row_data": [[...], ...]}``. ``headers`` is the
        flattened OCR text of the detected column-header region(s).
        ``row_data`` has one list of OCR texts per detected table row. Both are
        ordered top-to-bottom by the detected box's top edge.
        NOTE(review): the header row may also be detected as a "table row" and
        then appear in ``row_data`` as well. Confirm against real invoices.
    """
    width, height = image.size
    feature_extractor, model = _load_structure_model(_STRUCTURE_MODEL_NAME)
    encoding = feature_extractor(image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**encoding)
    # target_sizes is (height, width) per image.
    results = feature_extractor.post_process_object_detection(
        outputs, threshold=threshold, target_sizes=[(height, width)])[0]

    # Resolve class ids through the model config rather than hard-coded ids.
    # Id 4 is "table projected row header", not the column header. Sort the
    # detections top-to-bottom, because post-processing does not order them
    # spatially.
    id2label = model.config.id2label
    detections = sorted(
        ((id2label.get(int(label)), box.tolist())
         for label, box in zip(results["labels"], results["boxes"])),
        key=lambda detection: detection[1][1],
    )

    output = {"headers": [], "row_data": []}
    for label, box in detections:
        if label not in (_COLUMN_HEADER_LABEL, _ROW_LABEL):
            continue
        # xscale=1 widens the crop horizontally (x0 -> 0, x1 doubled then clamped to
        # the width); y_scale pads it slightly vertically.
        texts = _ocr_texts(image.crop(scaler(box, 1, y_scale, width, height)))
        if label == _COLUMN_HEADER_LABEL:
            output["headers"].extend(texts)
        else:
            output["row_data"].append(texts)
    return output

In [6]:
# Input invoice and template regions, in pixel coordinates of the original image.
# Each region is four [x, y] corners: top-left, top-right, bottom-right, bottom-left.
img_path = "../data/invoice_sample.jpg"
template_data_loc = {"Invoice Number":[[967.0, 365.0], [1074.0, 365.0], [1074.0, 386.0], [967.0, 386.0]],
                     "Date":[[781.0, 365.0], [876.0, 365.0], [876.0, 386.0], [781.0, 386.0]],
                     "Address":[[89.0, 605.0], [269.0, 605.0], [269.0, 661.0], [89.0, 661.0]],
                     "Table_Data":[[85.0, 855.0], [1160.0, 855.0], [1160.0, 1310.0], [85.0, 1310.0]]}

# (left, upper, right, lower) crop box for the table region, as PIL's Image.crop expects.
table_region = template_data_loc["Table_Data"]
table_box = (table_region[0][0], table_region[0][1], table_region[1][0], table_region[2][1])

# Perform edge detection and image distortion correction if required (not implemented yet).

# Extract the requested fields from the template. The table region is handled separately
# below, so it is excluded here. Otherwise its text would be OCR-joined into one string
# and then overwritten. PaddleOCR can return None for a page with no detected text.
data = ocr_model.ocr(img_path)[0] or []
field_template = {name: corners for name, corners in template_data_loc.items() if name != "Table_Data"}
extracted_data = extract_requested_data(field_template, data)

# Extract the table data. convert() returns a loaded copy, so the file can be closed.
with Image.open(img_path) as raw_image:
    image = raw_image.convert("RGB")
table = image.crop(table_box)
table_output = table_structure_detection(table)

# Join the template fields and the table into one result.
extracted_data["Table_Data"] = table_output



In [7]:
# Display the combined result: the template fields plus the extracted "Table_Data".
extracted_data

{'Date': '14/08/2023',
 'Invoice Number': 'F1000876/23',
 'Address': '255 Commercial Street 25880 New York, US',
 'Table_Data': {'headers': [],
  'row_data': [['Pole with bracket',
    '88565.2545',
    '1',
    '$85.00',
    '$85.00',
    'Country of origin: US'],
   ['Pole with bracket',
    '88565.2545',
    '1',
    '$85.00',
    '$85.00',
    'Country of origin: US'],
   ['Pole with bracket',
    '88565.2545',
    '1',
    '$85.00',
    '$85.00',
    'Country of origin: US'],
   ['Conveyor Belt 25 "',
    '88565.2252',
    '2',
    '$200.00',
    '$400.00',
    'Country of origin: US'],
   ['Pole with bracket',
    '88565.2545',
    '1',
    '$85.00',
    '$85.00',
    'Country of origin: US']]}}