In [5]:
## Install prerequisites
#!pip install -q git+https://github.com/huggingface/transformers.git
#!pip install -q timm

In [1]:
from functools import lru_cache

import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from transformers import DetrFeatureExtractor
from transformers import TableTransformerForObjectDetection

In [2]:
def scaler(box, xscale, yscale, xmax, ymax):
    """Pad a bounding box by scaling its corner coordinates, clamped to the image.

    The min corner is multiplied by ``1 - scale`` and the max corner by
    ``1 + scale``. The padding is therefore relative to each coordinate's value
    (its distance from the image origin), not to the box size. With ``xscale=1``
    the box is widened to ``[0, min(2 * x1, xmax)]``.

    Parameters
    ----------
    box : sequence of 4 numbers
        ``[x0, y0, x1, y1]`` in pixel coordinates.
    xscale, yscale : float
        Relative padding factors for the x and y axes.
    xmax, ymax : number
        Upper clamp for x and y (normally the image width and height).

    Returns
    -------
    list
        A new ``[x0, y0, x1, y1]`` list; ``box`` itself is not modified.
        Coordinates are clamped to ``[0, xmax]`` / ``[0, ymax]``.
    """
    x0, y0, x1, y1 = box
    return [
        # Lower clamp guards against scale factors > 1 producing negative coordinates.
        max(x0 * (1 - xscale), 0),
        max(y0 * (1 - yscale), 0),
        min(x1 * (1 + xscale), xmax),
        min(y1 * (1 + yscale), ymax),
    ]

In [3]:
@lru_cache(maxsize=None)
def _detection_model():
    """Load the table-detection checkpoint once per kernel session and reuse it."""
    return TableTransformerForObjectDetection.from_pretrained("microsoft/table-transformer-detection")


def table_detection(image, threshold=0.7, padding=0.05):
    """Detect tables in a page image and return a padded crop of each one.

    Parameters
    ----------
    image : PIL.Image.Image
        Page image (the driver cell below converts it to RGB first).
    threshold : float, default 0.7
        Minimum detection score passed to ``post_process_object_detection``.
    padding : float, default 0.05
        Relative padding applied on both axes via ``scaler``.

    Returns
    -------
    list of PIL.Image.Image
        One crop per detected table, in the order returned by the post-processor.
    """
    width, height = image.size
    feature_extractor = DetrFeatureExtractor()
    encoding = feature_extractor(image, return_tensors="pt")
    # Cached: from_pretrained is expensive and would otherwise run on every call.
    model = _detection_model()

    with torch.no_grad():
        outputs = model(**encoding)

    # target_sizes is (height, width) so boxes come back in pixel coordinates of `image`.
    results = feature_extractor.post_process_object_detection(
        outputs, threshold=threshold, target_sizes=[(height, width)]
    )[0]

    return [
        image.crop(scaler(box, padding, padding, width, height))
        for box in results['boxes'].tolist()
    ]

In [25]:
@lru_cache(maxsize=None)
def _structure_model():
    """Load the table-structure-recognition checkpoint once per kernel session and reuse it."""
    return TableTransformerForObjectDetection.from_pretrained("microsoft/table-transformer-structure-recognition")


def table_structure_detection(image, threshold=0.8, y_scale=0.03):
    """Split a cropped table image into column, row and header crops.

    Parameters
    ----------
    image : PIL.Image.Image
        A single table crop (e.g. one element returned by ``table_detection``).
    threshold : float, default 0.8
        Minimum detection score passed to ``post_process_object_detection``.
    y_scale : float, default 0.03
        Relative vertical padding applied via ``scaler``.

    Returns
    -------
    dict
        ``{'col_data', 'row_data', 'col_header_data', 'row_header_data'}`` ->
        list of PIL crops. Detections with any other label id are ignored.
    """
    # Label id -> output key. Ids are assumed to follow the checkpoint's
    # config.id2label; TODO confirm against model.config.id2label.
    label_keys = {1: 'col_data', 2: 'row_data', 3: 'col_header_data', 4: 'row_header_data'}

    width, height = image.size
    feature_extractor = DetrFeatureExtractor()
    encoding = feature_extractor(image, return_tensors="pt")
    # Cached: this function runs once per detected table, so reloading would be costly.
    model = _structure_model()

    with torch.no_grad():
        outputs = model(**encoding)
    results = feature_extractor.post_process_object_detection(
        outputs, threshold=threshold, target_sizes=[(height, width)]
    )[0]

    output = {key: [] for key in label_keys.values()}
    for label, box in zip(results['labels'].tolist(), results['boxes'].tolist()):
        key = label_keys.get(label)
        if key is None:
            continue
        # NOTE(review): xscale=1 makes scaler widen every crop to x in [0, min(2*x1, width)],
        # so 'col_data' crops span far more than the detected column. Confirm this is intended.
        output[key].append(image.crop(scaler(box, 1, y_scale, width, height)))

    return output

In [21]:
import itertools

import matplotlib.pyplot as plt

# colors for visualization (cycled when there are more boxes than colors)
COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
          [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]


def plot_results(model, pil_img, scores, labels, boxes):
    """Draw detection boxes with "label: score" captions on top of an image.

    Parameters
    ----------
    model : TableTransformerForObjectDetection
        Used only for ``model.config.id2label`` to name each label id.
    pil_img : PIL.Image.Image
        Image the boxes were predicted on.
    scores, labels, boxes : torch.Tensor
        Post-processed detection outputs; ``boxes`` rows are ``(xmin, ymin, xmax, ymax)``.
    """
    fig, ax = plt.subplots(figsize=(16, 10))
    ax.imshow(pil_img)
    # itertools.cycle never runs out, unlike a fixed-length repeated list, so no box is dropped by zip.
    for score, label, (xmin, ymin, xmax, ymax), color in zip(
        scores.tolist(), labels.tolist(), boxes.tolist(), itertools.cycle(COLORS)
    ):
        ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                   fill=False, color=color, linewidth=3))
        text = f'{model.config.id2label[label]}: {score:0.2f}'
        ax.text(xmin, ymin, text, fontsize=15,
                bbox=dict(facecolor='yellow', alpha=0.5))
    ax.axis('off')
    plt.show()

In [None]:
def extract_data(split_table, ocr_fn=None):
    """Run OCR on each image crop and collect the results in order.

    Parameters
    ----------
    split_table : iterable
        Image crops (``get_table_data`` passes lists of PIL crops).
    ocr_fn : callable, optional
        Function applied to each crop. Defaults to a global ``OCR_extract``.
        This notebook does not define that global (TODO: supply one).

    Returns
    -------
    list
        ``[ocr_fn(crop) for crop in split_table]``.

    Raises
    ------
    NameError
        If ``ocr_fn`` is omitted and no ``OCR_extract`` is defined.
    """
    if ocr_fn is None:
        try:
            ocr_fn = OCR_extract  # noqa: F821 - expected to be defined elsewhere in the session
        except NameError as exc:
            raise NameError(
                "extract_data needs an OCR function: pass ocr_fn=... or define OCR_extract"
            ) from exc
    return [ocr_fn(crop) for crop in split_table]

In [36]:
def get_table_data(img, ocr=False):
    """Detect tables in a page image and split each into structural crops, optionally OCR'd.

    Parameters
    ----------
    img : PIL.Image.Image
        Page image.
    ocr : bool, default False
        When True, run ``extract_data`` on each table's column-header and row crops.
        Off by default because ``extract_data`` relies on an OCR function that this
        notebook does not define yet.

    Returns
    -------
    dict[int, dict]
        Keyed by table index, in detection order.
        With ``ocr=False``, each value is the crop dictionary from
        ``table_structure_detection``.
        With ``ocr=True``, each value holds ``'header_data'`` and/or ``'row_data'``
        OCR results; a key is present only when that crop group is non-empty.
    """
    # Locate tables on the image.
    tables = table_detection(img)

    # Figure out headers and data inside each table.
    split_tables = {i: table_structure_detection(table) for i, table in enumerate(tables)}
    if not ocr:
        return split_tables

    # Perform OCR on each table detected.
    data = {}
    for i, split_table in split_tables.items():
        data[i] = {}
        if split_table['col_header_data']:
            data[i]['header_data'] = extract_data(split_table['col_header_data'])
        if split_table['row_data']:
            data[i]['row_data'] = extract_data(split_table['row_data'])
    return data

In [37]:
file_path = "test_images/invoice_sample.jpg"

# Open with a context manager so the file handle is released. convert() returns a
# converted copy, which stays usable after the file is closed.
with Image.open(file_path) as source_image:
    image = source_image.convert("RGB")

data = get_table_data(image)