In [1]:
import functools
from concurrent.futures import ThreadPoolExecutor
from threading import Thread
from threading import RLock

from paddleocr import PaddleOCR
from PIL import Image
import torch
from transformers import DetrFeatureExtractor
from transformers import TableTransformerForObjectDetection
import numpy as np
from matplotlib import pyplot as plt
import cv2
# Single shared English PaddleOCR instance used by every `ocr_model.ocr(...)` call
# in this notebook (text-angle classifier disabled, PaddleOCR logging suppressed).
ocr_model = PaddleOCR(show_log=False, use_angle_cls=False, lang='en')

In [2]:
def scan(img, orig_img):
    """Detect the page outline in ``img`` and return ``orig_img`` warped to a top-down view.

    Parameters
    ----------
    img : np.ndarray
        Image used only for edge / contour detection (expected to be a
        single-channel 8-bit grayscale array).
    orig_img : np.ndarray
        Image that is actually warped. Its channel order is passed through
        unchanged to the returned PIL image.

    Returns
    -------
    PIL.Image.Image
        The perspective-corrected page. When no four-corner outline is found
        among the largest contours (or the outline is degenerate), ``orig_img``
        itself is returned, also as a PIL image, so callers always get the
        same type back.
    """
    # Repeated closing wipes out the text so that mostly the page border yields edges.
    kernel = np.ones((5, 5), np.uint8)
    closed = cv2.morphologyEx(img, cv2.MORPH_CLOSE, kernel, iterations=3)
    canny = cv2.Canny(closed, 70, 300)
    canny = cv2.dilate(canny, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5)))

    contours, _ = cv2.findContours(canny, cv2.RETR_LIST, cv2.CHAIN_APPROX_NONE)
    # The five largest contours are the candidates for the page outline.
    candidates = sorted(contours, key=cv2.contourArea, reverse=True)[:5]

    corners = None
    for contour in candidates:
        epsilon = 0.02 * cv2.arcLength(contour, True)
        approx = cv2.approxPolyDP(contour, epsilon, True)
        if len(approx) == 4:
            corners = approx
            break

    if corners is None:
        # No quadrilateral found: warping with an arbitrary polygon's points
        # would produce garbage, so hand back the unwarped image instead.
        return Image.fromarray(orig_img)

    corners = order_points(sorted(np.concatenate(corners).tolist()))
    destination_corners = find_dest(corners)

    out_width, out_height = destination_corners[2]
    if out_width <= 0 or out_height <= 0:
        # Degenerate (collapsed) outline; warpPerspective needs a positive size.
        return Image.fromarray(orig_img)

    homography = cv2.getPerspectiveTransform(np.float32(corners), np.float32(destination_corners))
    final = cv2.warpPerspective(orig_img, homography, (out_width, out_height), flags=cv2.INTER_LINEAR)
    return Image.fromarray(final)

In [3]:
def order_points(pts):
    '''Return the corners of ``pts`` as [top-left, top-right, bottom-right, bottom-left].

    ``pts`` is a sequence of ``[x, y]`` points. Top-left / bottom-right are the
    points with the smallest / largest ``x + y``; top-right / bottom-left are the
    points with the smallest / largest ``y - x``. Coordinates pass through
    ``float32`` and are truncated to ``int`` in the returned nested list.
    '''
    points = np.array(pts)
    coord_sum = points[:, 0] + points[:, 1]
    coord_diff = points[:, 1] - points[:, 0]
    picks = [coord_sum.argmin(), coord_diff.argmin(), coord_sum.argmax(), coord_diff.argmax()]
    ordered = np.array([points[k] for k in picks], dtype='float32')
    return ordered.astype('int').tolist()

In [4]:
def find_dest(pts):
    """Compute the destination rectangle for a top-down warp of quadrilateral ``pts``.

    ``pts`` must hold exactly four points ordered top-left, top-right,
    bottom-right, bottom-left. The rectangle is anchored at the origin. Its
    width is the longer of the top and bottom edges and its height is the
    longer of the left and right edges, each truncated to ``int``. It is
    returned in the same corner order via ``order_points``.
    """
    top_left, top_right, bottom_right, bottom_left = pts

    def edge_length(p, q):
        # Euclidean distance between two corner points.
        return np.sqrt(((p[0] - q[0]) ** 2) + ((p[1] - q[1]) ** 2))

    width = max(int(edge_length(bottom_right, bottom_left)),
                int(edge_length(top_right, top_left)))
    height = max(int(edge_length(top_right, bottom_right)),
                 int(edge_length(top_left, bottom_left)))

    rectangle = [[0, 0], [width, 0], [width, height], [0, height]]
    return order_points(rectangle)

In [5]:
def show(img, title='Image'):
    """Display ``img`` in its own full-size matplotlib figure, without axis ticks.

    The original drew into ``subplot(122)``, i.e. only the right half of a
    1x2 grid, leaving the left half empty. This version uses a single axes.

    Parameters
    ----------
    img : array-like or PIL.Image.Image
        Image to draw. 2-D (grayscale) data uses the ``gray`` colormap;
        matplotlib ignores ``cmap`` for RGB(A) data.
    title : str, optional
        Figure title (default ``'Image'``, matching the previous behavior).
    """
    fig, ax = plt.subplots()
    ax.imshow(img, cmap='gray')
    ax.set_title(title)
    ax.set_xticks([])
    ax.set_yticks([])
    plt.show()

In [6]:
def get_midpoints(data):
    """Map the centroid of every OCR box to its recognized text.

    Parameters
    ----------
    data : list or None
        One page of PaddleOCR output: entries of the form
        ``[box, (text, score)]``, where ``box`` is a list of ``[x, y]`` corner
        points. ``None`` or an empty list (no text found) yields ``{}``.

    Returns
    -------
    dict
        ``{(x_center, y_center): text}``. Boxes with identical centroids
        collapse to the last one seen.
    """
    output = {}
    # PaddleOCR can return None instead of a list for a page with no text.
    for entry in data or []:
        corners = entry[0]
        # Mean of the corner points; equals the old "/4.0" for 4-point boxes.
        x_center = sum(pt[0] for pt in corners) / len(corners)
        y_center = sum(pt[1] for pt in corners) / len(corners)
        output[(x_center, y_center)] = entry[1][0]
    return output

In [7]:
def extract_requested_data(template_data_loc, data):
    """Collect OCR text whose box centroid falls strictly inside a template region.

    Parameters
    ----------
    template_data_loc : dict
        Field name -> four ``[x, y]`` corners. The x-range is taken from
        corners 0 and 1 and the y-range from corners 0 and 2.
    data : list
        One page of PaddleOCR output, passed straight to ``get_midpoints``.

    Returns
    -------
    dict
        Field name -> text of every matching OCR box, joined with single
        spaces in OCR order. Fields with no match are absent.
    """
    collected = {}
    for (cx, cy), text in get_midpoints(data).items():
        for field, corners in template_data_loc.items():
            left, right = corners[0][0], corners[1][0]
            top, bottom = corners[0][1], corners[2][1]
            if not (left < cx < right and top < cy < bottom):
                continue
            collected[field] = collected[field] + " " + text if field in collected else text
    return collected

In [8]:
def scaler(box, xscale, yscale, xmax, ymax):
    """Enlarge a ``[x0, y0, x1, y1]`` box by scaling its coordinates, clamped to the image.

    Coordinates are multiplied, not offset: ``x0``/``y0`` by ``1 - scale`` and
    ``x1``/``y1`` by ``1 + scale``. The padding therefore grows with the box's
    distance from the origin, not with the box's size. With ``xscale=1`` the
    box spans from x = 0 to ``xmax``.

    Parameters
    ----------
    box : sequence of float
        ``[x0, y0, x1, y1]``; any extra trailing elements are kept as-is.
    xscale, yscale : float
        Relative scale factors for the x and y coordinates.
    xmax, ymax : float
        Upper bounds (image width / height) for ``x1`` and ``y1``.

    Returns
    -------
    list
        A new list. ``box`` itself is no longer modified. Coordinates are
        clamped to ``[0, xmax]`` x ``[0, ymax]``; the lower clamp matters when
        a scale factor exceeds 1.
    """
    scaled = list(box)
    scaled[0] = max(scaled[0] * (1 - xscale), 0)
    scaled[1] = max(scaled[1] * (1 - yscale), 0)
    scaled[2] = min(scaled[2] * (1 + xscale), xmax)
    scaled[3] = min(scaled[3] * (1 + yscale), ymax)
    return scaled

In [9]:
class ThreadSafeList:
    """Minimal list wrapper whose ``append`` is serialized by a re-entrant lock.

    Indexing and ``len()`` read the underlying list directly without taking
    the lock.
    """

    def __init__(self):
        self._lock = RLock()
        self._list = []

    def append(self, data):
        """Append ``data`` to the list while holding the lock."""
        with self._lock:
            self._list.append(data)

    def __getitem__(self, key):
        """Return ``self._list[key]``; ints and slices both work."""
        items = self._list
        return items[key]

    def __len__(self):
        """Return the number of stored items."""
        return self._list.__len__()

In [10]:
def ocr_thread(np_img, result):
    """Thread target: OCR ``np_img`` with the shared ``ocr_model`` and append the raw result.

    ``result`` is any object with an ``append`` method (e.g. ``ThreadSafeList``).
    Items are appended in completion order, not submission order.
    """
    ocr_output = ocr_model.ocr(np_img)
    result.append(ocr_output)

In [11]:
# Label ids from the table-structure model that this notebook acts on.
# NOTE(review): the published config for microsoft/table-transformer-structure-recognition
# lists 2 = "table row", 3 = "table column header", 4 = "table projected row header".
# Header detection is keyed to 4 as in the original code; verify against
# model.config.id2label before relying on it.
_HEADER_LABEL = 4
_ROW_LABEL = 2


@functools.lru_cache(maxsize=None)
def _load_table_structure_model():
    """Load the DETR feature extractor and table-structure model once and reuse them."""
    feature_extractor = DetrFeatureExtractor()
    model = TableTransformerForObjectDetection.from_pretrained(
        "microsoft/table-transformer-structure-recognition")
    return feature_extractor, model


def _ocr_texts(np_img):
    """Return the recognized strings of one image crop, in PaddleOCR's output order."""
    page = ocr_model.ocr(np_img)[0]
    # PaddleOCR can return None for a crop with no detected text.
    return [line[1][0] for line in page or []]


def table_structure_detection(image, threshold=0.7, y_scale=0.03, max_workers=None):
    """Detect header and row regions in a table image and OCR each one.

    Parameters
    ----------
    image : PIL.Image.Image
        Cropped table image.
    threshold : float, optional
        Detection score threshold passed to ``post_process_object_detection``
        (default 0.7, as before).
    y_scale : float, optional
        Vertical padding factor passed to ``scaler`` for each crop (default 0.03).
    max_workers : int or None, optional
        Thread-pool size for row OCR. ``None`` uses the executor default; ``1``
        serializes the calls.

    Returns
    -------
    dict
        ``{"headers": [str, ...], "row_data": [[str, ...], ...]}``. Rows are
        ordered top-to-bottom by the detected box's upper edge; previously
        their order depended on which OCR thread finished first.
    """
    width, height = image.size
    # Loading from_pretrained on every call is expensive; the loader is cached.
    feature_extractor, model = _load_table_structure_model()
    encoding = feature_extractor(image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**encoding)
    results = feature_extractor.post_process_object_detection(
        outputs, threshold=threshold, target_sizes=[image.size[::-1]])[0]

    output = {"headers": [], "row_data": []}
    row_boxes = []
    for label, box in zip(results["labels"].tolist(), results["boxes"].tolist()):
        if label == _HEADER_LABEL:
            # xscale=1 widens the crop to the full table width.
            header_img = np.asarray(image.crop(scaler(box, 1, y_scale, width, height)))
            output["headers"].extend(_ocr_texts(header_img))
        elif label == _ROW_LABEL:
            row_boxes.append(box)

    # Detector output order is not spatial; sort rows by their top edge.
    row_boxes.sort(key=lambda b: b[1])
    row_imgs = [np.asarray(image.crop(scaler(b, 1, y_scale, width, height))) for b in row_boxes]

    # Executor.map yields results in input order, so row order stays deterministic.
    # NOTE(review): all workers share one PaddleOCR instance; confirm concurrent
    # .ocr() calls are safe for the installed version, or pass max_workers=1.
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        output["row_data"] = list(pool.map(_ocr_texts, row_imgs))
    return output

In [12]:
def _split_header_rows(rows):
    """Split OCR'd rows into (rows containing no ASCII digit, all other rows), keeping order."""
    header_rows, data_rows = [], []
    for row in rows:
        row_text = ",".join(row)
        if any(digit in row_text for digit in "1234567890"):
            data_rows.append(row)
        else:
            header_rows.append(row)
    return header_rows, data_rows


def run_ocr_with_temp(image, template_data_loc, template_size, edge_detect=False):
    """Extract template fields and the line-item table from a document image.

    Parameters
    ----------
    image : PIL.Image.Image
        Page image to process.
    template_data_loc : dict
        Field name -> four ``[x, y]`` corner points (top-left, top-right,
        bottom-right, bottom-left) in template-page pixels. It must contain a
        ``"Table_Data"`` entry locating the table.
    template_size : tuple
        ``(width, height)`` that the page is resized to before regions are read.
    edge_detect : bool, optional
        If True, first detect the page outline and warp it to a top-down view.

    Returns
    -------
    dict
        Field name -> space-joined OCR text for each template field that
        matched, plus ``"Table_Data"`` holding the ``{"headers", "row_data"}``
        dict from ``table_structure_detection``.
        NOTE(review): headers found by the model are a flat list of strings,
        while the fallback below stores whole rows (a list of lists), so
        callers can see two different shapes.
    """
    table_corners = template_data_loc["Table_Data"]
    # PIL crop box: (left, upper, right, lower).
    table_box = (table_corners[0][0], table_corners[0][1],
                 table_corners[1][0], table_corners[2][1])

    if edge_detect:
        rgb = np.array(image.convert("RGB"))
        # cvtColor needs a COLOR_* code. The old IMREAD_GRAYSCALE (0) equals
        # COLOR_BGR2BGRA and produced a 4-channel image. The RGB array itself is
        # what gets warped, so the PIL image scan() returns keeps RGB order.
        gray = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY)
        image = scan(gray, rgb)

    # Template coordinates are only valid at the template's resolution.
    image = image.resize(template_size)

    page = np.asarray(image)
    # PaddleOCR may yield None for a page with no detected text.
    data = ocr_model.ocr(page)[0] or []
    extracted_data = extract_requested_data(template_data_loc, data)

    table_output = table_structure_detection(image.crop(table_box))

    if not table_output["headers"]:
        # Fallback: treat digit-free rows as headers. Partitioning avoids the old
        # bug of removing from row_data while iterating over it, which skipped rows.
        table_output["headers"], table_output["row_data"] = _split_header_rows(
            table_output["row_data"])

    extracted_data["Table_Data"] = table_output
    return extracted_data

In [13]:
img_path = "../data/invoice_sample.jpg"
# Template field regions: four [x, y] corners (top-left, top-right, bottom-right,
# bottom-left) in pixels of the template page. "Table_Data" marks the line-item table.
template_data_loc = {"Invoice Number": [[967.0, 365.0], [1074.0, 365.0], [1074.0, 386.0], [967.0, 386.0]],
                     "Date": [[781.0, 365.0], [876.0, 365.0], [876.0, 386.0], [781.0, 386.0]],
                     "Address": [[89.0, 605.0], [269.0, 605.0], [269.0, 661.0], [89.0, 661.0]],
                     "Table_Data": [[85.0, 840.0], [1160.0, 840.0], [1160.0, 1310.0], [85.0, 1310.0]]}
# (width, height) the input page is resized to before template regions are applied.
template_size = (1240, 1754)

# The context manager closes the file handle once the RGB copy has been made.
with Image.open(img_path) as raw_image:
    image = raw_image.convert("RGB")
run_ocr_with_temp(image, template_data_loc, template_size, edge_detect=False)



{'Date': '14/08/2023',
 'Invoice Number': 'F1000876/23',
 'Address': '255 Commercial Street 25880 New York, US',
 'Table_Data': {'headers': [['PRODUCT',
    'HS CODE',
    'UNITS',
    'UNIT PRICE',
    'TOTAL']],
  'row_data': [['Country of origin: US',
    '88565.2545',
    '1',
    '$85.00',
    '$85.00',
    'Pole with bracket'],
   ['Polewitn bracket',
    '88565.2545',
    '1',
    '$85.00',
    '$85.00',
    'Country of origin: Us'],
   ['Polewitn bracket',
    '88565.2545',
    '1',
    '$85.00',
    '$85.00',
    'Country of origin: Us'],
   ['Pole with bracket',
    '88565.2545',
    '1',
    '$85.00',
    '$85.00',
    'Country of origin: US'],
   ['Pole with bracket',
    '88565.2545',
    '1',
    '$85.00',
    '$85.00',
    'Country of origin: Us']]}}