In [1]:
import functools
import os
import threading
from concurrent.futures import ThreadPoolExecutor

from paddleocr import PaddleOCR
from PIL import Image
import torch
from huggingface_hub import hf_hub_download
from transformers import DetrFeatureExtractor
from transformers import TableTransformerForObjectDetection
import numpy as np
from matplotlib import pyplot as plt
import cv2
# One shared English PaddleOCR engine, reused by every OCR call in this notebook
# (angle classification disabled, PaddleOCR logging silenced).
ocr_model = PaddleOCR(
    lang='en',
    use_angle_cls=False,
    show_log=False,
)

In [2]:
def scan(img, orig_img, debug=False):
    """Find the document outline in a photo and return a perspective-corrected view.

    Args:
        img: Single-channel image used for edge detection (the caller loads it
            with ``cv2.IMREAD_GRAYSCALE``).
        orig_img: Image that is actually warped; expected to be the same photo
            as ``img`` (the caller reads both from one file).
        debug: When True, display the dilated Canny edge map with matplotlib.

    Returns:
        The warped page, or ``orig_img`` unchanged when none of the five largest
        contours approximates to a quadrilateral, or when that quadrilateral
        would produce an empty output image.
    """
    # Repeated closing erases text so mainly the page outline survives.
    kernel = np.ones((5, 5), np.uint8)
    closed = cv2.morphologyEx(img, cv2.MORPH_CLOSE, kernel, iterations=3)
    canny = cv2.Canny(closed, 70, 300)
    canny = cv2.dilate(canny, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5)))
    if debug:
        from matplotlib import pyplot as plt
        plt.imshow(canny, cmap="gray")
        plt.axis("off")
        plt.show()

    contours, _ = cv2.findContours(canny, cv2.RETR_LIST, cv2.CHAIN_APPROX_NONE)
    # Only the five largest contours are considered as page-outline candidates.
    candidates = sorted(contours, key=cv2.contourArea, reverse=True)[:5]

    corners = None
    for contour in candidates:
        epsilon = 0.02 * cv2.arcLength(contour, True)
        approx = cv2.approxPolyDP(contour, epsilon, True)
        if len(approx) == 4:
            corners = approx
            break
    if corners is None:
        # No quadrilateral found: warping some other polygon would be meaningless.
        return orig_img

    # Flatten (4, 1, 2) -> list of [x, y], then order tl, tr, br, bl.
    corners = order_points(sorted(np.concatenate(corners).tolist()))
    destination_corners = find_dest(corners)
    out_width, out_height = destination_corners[2]
    if out_width <= 0 or out_height <= 0:
        # Degenerate (collapsed) quadrilateral: no meaningful output size.
        return orig_img

    M = cv2.getPerspectiveTransform(np.float32(corners), np.float32(destination_corners))
    return cv2.warpPerspective(orig_img, M, (out_width, out_height), flags=cv2.INTER_LINEAR)

In [3]:
def order_points(pts):
    """Return four (x, y) points ordered top-left, top-right, bottom-right, bottom-left.

    Uses the sum/difference trick: x + y is smallest at the top-left and largest
    at the bottom-right, while y - x is smallest at the top-right and largest at
    the bottom-left. Coordinates pass through float32 and are truncated to
    Python ints in the returned nested list.
    """
    points = np.array(pts)
    coord_sum = points.sum(axis=1)
    coord_diff = np.diff(points, axis=1)
    ordered = np.array(
        [
            points[np.argmin(coord_sum)],
            points[np.argmin(coord_diff)],
            points[np.argmax(coord_sum)],
            points[np.argmax(coord_diff)],
        ],
        dtype='float32',
    )
    return ordered.astype('int').tolist()

In [4]:
def find_dest(pts):
    """Compute destination corners for warping a quadrilateral to an upright rectangle.

    Args:
        pts: Four (x, y) points ordered top-left, top-right, bottom-right,
            bottom-left (as produced by ``order_points``).

    Returns:
        ``[[0, 0], [W, 0], [W, H], [0, H]]`` as Python ints. W is the longer of
        the top/bottom edges and H the longer of the left/right edges, each
        truncated to an integer.
    """
    tl, tr, br, bl = pts

    def edge_length(p, q):
        return np.hypot(p[0] - q[0], p[1] - q[1])

    max_width = max(int(edge_length(br, bl)), int(edge_length(tr, tl)))
    max_height = max(int(edge_length(tr, br)), int(edge_length(tl, bl)))
    # Built directly in tl, tr, br, bl order, so there is no need to re-sort it
    # (and no float32 round-trip that could round very large sizes).
    return [[0, 0], [max_width, 0], [max_width, max_height], [0, max_height]]

In [5]:
def scaler(box, xscale, yscale, xmax, ymax):
    """Pad a bounding box by scaling its coordinates, clamped to the image.

    The top-left corner is multiplied by ``1 - scale`` and the bottom-right by
    ``1 + scale``. The scaling is relative to the image origin, not the box, so
    boxes far from (0, 0) grow more than boxes near it. ``xscale=1`` moves the
    left edge to 0 and doubles the right edge (before clamping).

    Args:
        box: ``[x0, y0, x1, y1]`` coordinates.
        xscale: Horizontal padding factor.
        yscale: Vertical padding factor.
        xmax: Upper bound for ``x1`` (image width).
        ymax: Upper bound for ``y1`` (image height).

    Returns:
        A new ``[x0, y0, x1, y1]`` list with x0/y0 clamped at 0 and x1/y1 clamped
        at ``xmax``/``ymax``. ``box`` itself is not modified.
    """
    return [
        max(box[0] * (1 - xscale), 0),
        max(box[1] * (1 - yscale), 0),
        min(box[2] * (1 + xscale), xmax),
        min(box[3] * (1 + yscale), ymax),
    ]

In [6]:
@functools.lru_cache(maxsize=None)
def _load_table_detector(model_name="microsoft/table-transformer-detection"):
    """Load the feature extractor and table-detection model once and reuse them."""
    feature_extractor = DetrFeatureExtractor()
    model = TableTransformerForObjectDetection.from_pretrained(model_name)
    return feature_extractor, model


def table_detection(image, threshold=0.7, scale=0.1):
    """Detect tables on a page, crop them out, and blank them on a copy of the page.

    Args:
        image: PIL image of the page (the caller passes an RGB image).
        threshold: Minimum detection score for a box to be kept.
        scale: Padding factor passed to ``scaler`` for each detected box.

    Returns:
        ``(cropped_img, image_no_table)``: a list of cropped table images in
        detection order, and a copy of ``image`` with every padded table box
        painted white. ``image`` itself is not modified.
    """
    width, height = image.size
    # The model is cached, so repeated calls do not reload the weights.
    feature_extractor, model = _load_table_detector()
    encoding = feature_extractor(image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**encoding)
    results = feature_extractor.post_process_object_detection(
        outputs, threshold=threshold, target_sizes=[(height, width)]
    )[0]

    # Pad each detected box (via scaler) before cropping.
    all_boxes = [scaler(box, scale, scale, width, height) for box in results["boxes"].tolist()]
    cropped_img = [image.crop(box) for box in all_boxes]

    # Blank the tables on a copy so the caller's image stays intact.
    image_no_table = image.copy()
    bands = len(image_no_table.getbands())
    white = 255 if bands == 1 else (255,) * bands
    for x0, y0, x1, y1 in all_boxes:
        left, top = int(x0), int(y0)
        # Same pixel extent as the old mask: integer width/height from the top-left corner.
        right, bottom = left + int(x1 - x0), top + int(y1 - y0)
        if right > left and bottom > top:
            image_no_table.paste(white, (left, top, right, bottom))
    return cropped_img, image_no_table

In [7]:
@functools.lru_cache(maxsize=None)
def _load_structure_recognizer(model_name="microsoft/table-transformer-structure-recognition"):
    """Load the feature extractor and table-structure model once and reuse them."""
    feature_extractor = DetrFeatureExtractor()
    model = TableTransformerForObjectDetection.from_pretrained(model_name)
    return feature_extractor, model


def _ocr_strip_texts(np_img):
    """OCR one image strip and return its recognized text lines ([] when none)."""
    result = ocr_model.ocr(np_img)
    # NOTE(review): some PaddleOCR versions return [None] when no text is found.
    lines = result[0] if result else None
    return [line[1][0] for line in lines or []]


def table_structure_detection(image, threshold=0.7, y_scale=0.03, header_label=4, row_label=2):
    """Detect header and row strips in a cropped table image and OCR each strip.

    Args:
        image: PIL image of a single table.
        threshold: Minimum detection score kept by post-processing.
        y_scale: Vertical padding factor passed to ``scaler`` for each strip.
        header_label: Detector label id whose boxes are OCR'd into ``headers``.
        row_label: Detector label id whose boxes are OCR'd into ``row_data``.

    Returns:
        ``{"headers": [str, ...], "row_data": [[str, ...], ...]}``. Header text
        and rows are ordered top to bottom by the detected box position.
    """
    # NOTE(review): the original comment mapped 1=column, 2=row, 3=column header,
    # 4=row header. The structure-recognition model's id2label reportedly names
    # 3 "table column header" and 4 "table projected row header". Confirm
    # which id should feed "headers".
    width, height = image.size
    feature_extractor, model = _load_structure_recognizer()
    encoding = feature_extractor(image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**encoding)
    results = feature_extractor.post_process_object_detection(
        outputs, threshold=threshold, target_sizes=[(height, width)]
    )[0]

    header_boxes, row_boxes = [], []
    for label, box in zip(results["labels"].tolist(), results["boxes"].tolist()):
        if label == header_label:
            header_boxes.append(box)
        elif label == row_label:
            row_boxes.append(box)
    # Detector output order is arbitrary; sort by top edge so text reads top to bottom.
    header_boxes.sort(key=lambda b: b[1])
    row_boxes.sort(key=lambda b: b[1])

    def crop_strip(box):
        # xscale=1: scaler moves the left edge to 0 and doubles the right edge
        # (clamped to width), widening the strip horizontally.
        return np.asarray(image.crop(scaler(box, 1, y_scale, width, height)))

    output = {"headers": [], "row_data": []}
    for box in header_boxes:
        output["headers"].extend(_ocr_strip_texts(crop_strip(box)))

    # Rows are still OCR'd concurrently, but map() returns results in submission
    # (top-to-bottom) order and re-raises worker exceptions.
    # NOTE(review): one PaddleOCR instance is shared across threads, as before;
    # confirm it is thread-safe, otherwise pass max_workers=1.
    row_images = [crop_strip(box) for box in row_boxes]
    with ThreadPoolExecutor() as pool:
        output["row_data"] = list(pool.map(_ocr_strip_texts, row_images))
    return output

In [8]:
def ocr_thread(np_img, result, lock):
    """Thread worker: OCR ``np_img`` with the shared engine and append the raw result.

    Args:
        np_img: Image array passed to ``ocr_model.ocr``.
        result: Shared list that receives the raw OCR output. Entries arrive in
            thread-completion order, not submission order.
        lock: Lock guarding ``result`` across workers.
    """
    # OCR runs outside the lock so workers can overlap; only the append is guarded.
    data = ocr_model.ocr(np_img)
    # The context manager releases the lock even if the append raises.
    with lock:
        result.append(data)

In [9]:
def _ocr_text_lines(np_img):
    """OCR an image array and return the recognized text lines ([] when none)."""
    result = ocr_model.ocr(np_img)
    # NOTE(review): some PaddleOCR versions return [None] for an image with no text.
    lines = result[0] if result else None
    return [line[1][0] for line in lines or []]


def _promote_header_rows(table):
    """Move every row with no ASCII digit from table["row_data"] to table["headers"]."""
    digits = set("1234567890")
    headers, rows = [], []
    # Build new lists instead of removing while iterating, which skipped rows.
    for row in table["row_data"]:
        has_digit = any(ch in digits for cell in row for ch in cell)
        if has_digit:
            rows.append(row)
        else:
            headers.append(row)
    table["headers"].extend(headers)
    table["row_data"] = rows


def run_ocr_no_temp(img_path, edge_detect=False, debug=False):
    """Extract free text and table contents from an invoice-like image.

    Args:
        img_path: Path to the input image.
        edge_detect: When True, perspective-correct the photo with ``scan`` first.
            The corrected image is written next to the input as
            ``<name>_cropped<ext>`` and used for the rest of the pipeline.
        debug: When True (and ``edge_detect`` is True), display the corrected image.

    Returns:
        ``{"unorganized_data": [str, ...],
           "Table_Data": {"headers": [per-table headers],
                          "row_data": [per-table rows]}}``

    Raises:
        FileNotFoundError: if ``edge_detect`` is True and OpenCV cannot read ``img_path``.
    """
    if edge_detect:
        img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        orig_img = cv2.imread(img_path)
        # cv2.imread signals failure by returning None instead of raising.
        if img is None or orig_img is None:
            raise FileNotFoundError(f"Could not read image: {img_path}")
        output = scan(img, orig_img)
        if debug:
            from matplotlib import pyplot as plt
            plt.imshow(cv2.cvtColor(output, cv2.COLOR_BGR2RGB))
            plt.axis("off")
            plt.show()
        # splitext handles extensions of any length (".jpeg", ".tiff").
        root, ext = os.path.splitext(img_path)
        img_path = root + "_cropped" + ext
        cv2.imwrite(img_path, output)

    # Find tables; image_no_table is the page with table regions painted white.
    image = Image.open(img_path).convert("RGB")
    tables, image_no_table = table_detection(image)

    # Text outside the tables is returned as a flat list of lines.
    data_ls = _ocr_text_lines(np.asarray(image_no_table))

    table_output = [table_structure_detection(table_img) for table_img in tables]

    # If no header strip was detected, guess headers from digit-free rows.
    # NOTE(review): guessed headers are row lists, while detected headers are flat strings.
    for table in table_output:
        if not table["headers"]:
            _promote_header_rows(table)

    return {
        "unorganized_data": data_ls,
        "Table_Data": {
            "headers": [table["headers"] for table in table_output],
            "row_data": [table["row_data"] for table in table_output],
        },
    }

In [10]:
# Run the pipeline on the sample invoice without perspective correction;
# the returned dict is displayed as this cell's output.
img_path = "../data/invoice_sample.jpg"
run_ocr_no_temp(img_path=img_path, edge_detect=False)

Downloading (…)lve/main/config.json:   0%|          | 0.00/1.23k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/115M [00:00<?, ?B/s]

TypeError: Cannot handle this data type: (1, 1), <i8