## Imports

In [None]:
from pathlib import Path
import json
from collections import defaultdict

In [None]:
import cv2
from IPython.display import display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pytesseract as tess
from PIL import Image
from sklearn.cluster import DBSCAN

## Constants

In [None]:
# Upscaling factor applied to an image before it is handed to Tesseract
# (see get_tessdata, which divides the resulting boxes by the same factor).
TESS_SCALE = 5
DATA_DIR = Path(
    r"C:\Users\stans\Documents\Projects\Datasets\pubtabnet.tar\pubtabnet\pubtabnet"
)  # C:\datasets\pubtabnet
TEST_DIR = DATA_DIR.joinpath("train")

# Named colour triples that draw() accepts for box outlines.
COLORS = dict(
    black=(0, 0, 0),
    red=(150, 0, 0),
    green=(0, 150, 0),
    blue=(0, 0, 150),
)

# Inline HTML formatting tags that clean_text() removes from cell text.
FORMAT_CHARS = [
    '<b>', '</b>',
    '<i>', '</i>',
    '<sup>', '</sup>',
    '<sub>', '</sub>',
]

## Image Utils

Functions:
* `scale(im, scale)`: used to scale images (needed for accurate results from tesseract)
* `draw(im, bbox, fmt='cv2', color='black')`: draws a rectangle

In [None]:
def scale(im, scale):
    """Return ``im`` resized by the factor ``scale``.

    im: image object exposing ``size`` as (width, height) and a ``resize``
        method, e.g. a PIL.Image.
    scale: multiplier applied to both dimensions; each product is truncated
        to an int.
    """
    # Build a real tuple: resize() is documented to take a 2-tuple, and a
    # generator can only be consumed once if the callee inspects it twice.
    new_size = tuple(int(side * scale) for side in im.size)
    return im.resize(new_size)

def draw(im, bbox, fmt='cv2', color='black'):
    """Draw a 1-pixel rectangle outline and return the image as an array.

    im: PIL.Image or array-like. It is converted with ``np.array`` first, so
        the rectangle is drawn on a new array rather than on the input.
    bbox: four numbers whose meaning depends on ``fmt``.
    fmt: 'cv2' for (x, y, width, height); 'pil' for (x1, y1, x2, y2) corners.
    color: a key of COLORS, or a colour tuple passed straight to
        ``cv2.rectangle``.

    Raises ValueError if ``fmt`` is neither 'cv2' nor 'pil'.
    """
    im = np.array(im)  # in case it is in PIL format

    if fmt == 'cv2':
        x1, y1, w, h = bbox
        x2, y2 = x1 + w, y1 + h
    elif fmt == 'pil':
        # Corner format already carries the opposite corner. The previous code
        # derived a "width/height" from the corner midpoints, which put the
        # bottom-right corner in the wrong place.
        x1, y1, x2, y2 = bbox
    else:
        raise ValueError(f"unknown bbox format: {fmt!r}")

    if isinstance(color, str):
        color = COLORS[color]

    # int() so numpy integer scalars (e.g. rows of a DataFrame) are accepted.
    pt1 = (int(x1), int(y1))
    pt2 = (int(x2), int(y2))
    return cv2.rectangle(im, pt1=pt1, pt2=pt2, color=color, thickness=1)

## Load Data

In [None]:
# Sorted so that index i always refers to the same image: glob() yields
# entries in whatever order the filesystem reports them.
files = sorted(TEST_DIR.glob("*.png"))

In [None]:
# Read the annotation file into memory as a list of raw lines, one JSON
# document per line; parsing is deferred until a record is needed.
annotations_path = DATA_DIR / "PubTabNet_2.0.0.jsonl"
with annotations_path.open("r", encoding="utf8") as f:
    json_list = f.readlines()

In [None]:
# Map each annotation's "filename" to the position of its line in json_list,
# so a record can be looked up from an image file name.
fname_index_map = {
    json.loads(line)["filename"]: index
    for index, line in enumerate(json_list)
}

## Data Utils
Utility functions for interacting with the PubTabNet Dataset


Functions:
* `X(i)`: returns the i'th image as a PIL.Image object
* `Y(i)`: returns the i'th json object (dict) representing the target of our model
* `get_data(i)`: similar to `Y(i)` but in dataframe format

In [None]:
def Y(i):
    """Return the parsed annotation (a dict) for the i'th image in ``files``.

    The image's file name is looked up in ``fname_index_map`` to find the
    matching line of ``json_list``, which is then decoded.
    """
    image_name = files[i].name
    line_no = fname_index_map[image_name]
    return json.loads(json_list[line_no])

def X(i):
    """Open the i'th path in ``files`` and return it as a PIL.Image."""
    image_path = files[i]
    return Image.open(image_path)

In [None]:
def html_iterator(y):
    """Yield one layout record per table cell, in document order.

    Walks ``y['html']['structure']['tokens']`` and yields a dict at every
    ``</td>`` with these keys:

    * ``col_start`` / ``col_end``: first and last grid column covered (inclusive)
    * ``row_start`` / ``row_end``: first and last grid row covered (inclusive)
    * ``is_head`` / ``is_body``: whether the cell is inside <thead> / <tbody>
    * ``raw``: every structure token consumed since the previous cell, joined

    A cell with ``rowspan`` reserves its columns in the rows below it, so cells
    in those rows are placed after the reserved columns.

    Raises ValueError on a token that is not recognised.
    """
    tokens = y['html']['structure']['tokens']

    row = 0
    column = 0
    is_head = False
    is_body = False
    row_start = None
    col_start = 0
    colspan = 1
    rowspan = 1
    reserved = set()  # (row, column) slots claimed by a rowspan from above
    pending = []      # tokens seen since the last yielded cell

    for t in tokens:
        pending.append(t)
        if '<td' in t:
            # Skip slots occupied by cells spanning down from earlier rows.
            while (row, column) in reserved:
                column += 1
            col_start = column
            row_start = row
            colspan = 1
            rowspan = 1

        elif t == '</td>':
            col_end = col_start + colspan - 1
            row_end = row + rowspan - 1
            for r in range(row + 1, row_end + 1):
                for c in range(col_start, col_end + 1):
                    reserved.add((r, c))
            column = col_end + 1
            yield {
                'col_start': col_start,
                'col_end': col_end,
                'row_start': row_start,
                'row_end': row_end,
                'is_head': is_head,
                'is_body': is_body,
                'raw': ''.join(pending),
            }
            pending = []

        elif t == '<thead>':
            is_head = True

        elif t == '</thead>':
            is_head = False

        elif t == '<tr>':
            pass

        elif t == '</tr>':
            column = 0
            row += 1

        elif "colspan" in t:
            # attribute token looks like ' colspan="3"'; take the quoted int
            colspan = int(t.split('"')[1])

        elif "rowspan" in t:
            # attribute token looks like ' rowspan="2"'; take the quoted int
            rowspan = int(t.split('"')[1])

        elif t == '<tbody>':
            is_body = True

        elif t == '</tbody>':
            is_body = False

        elif t == '>':
            # closes an opening '<td' that carried attributes
            pass

        else:
            raise ValueError(t)

In [None]:
def clean_text(cell):
    """Return a cell's text with inline formatting tags removed.

    cell: dict whose "tokens" entry is a list of string fragments.

    The fragments are joined, every tag listed in FORMAT_CHARS is deleted, and
    surrounding whitespace is stripped from the result.
    """
    text = ''.join(cell["tokens"])
    for fmt in FORMAT_CHARS:
        text = text.replace(fmt, '')
    # Strip last: whitespace may sit just inside a tag ("<b> x </b>") and only
    # becomes leading/trailing once the tag is gone.
    return text.strip()

In [None]:
def get_data(i):
    """Return the annotated cells of the i'th image as a DataFrame.

    One row is produced per cell that has a 'bbox'. Columns are the fields
    yielded by html_iterator plus:

    * ``text``: the cell text as returned by clean_text
    * ``bbox``: an (x, y, width, height) tuple derived from the annotation's
      (x1, y1, x2, y2) corners

    Cells without a 'bbox' are left out, but their structure record is still
    consumed so that later cells stay aligned with the structure tokens.
    An annotation with no bbox cells yields an empty DataFrame.
    """
    data = Y(i)
    struct = html_iterator(data)

    rows = []
    for cell in data['html']['cells']:
        cell_data = next(struct)
        if 'bbox' not in cell:
            continue
        x1, y1, x2, y2 = cell["bbox"]
        # No trailing comma here: a stray one used to wrap the text in a
        # 1-tuple, which then had to be unpacked again after building the frame.
        cell_data['text'] = clean_text(cell)
        cell_data['bbox'] = (x1, y1, x2 - x1, y2 - y1)
        rows.append(cell_data)

    return pd.DataFrame(rows)

In [None]:
def display_ground_truth_bboxes(i):
    """Return the i'th image with every annotated cell bbox outlined in green.

    Only cells that carry a 'bbox' entry are drawn; the boxes are passed to
    draw() in 'pil' (corner) format.
    """
    canvas = X(i)
    for cell in Y(i)['html']['cells']:
        if 'bbox' not in cell:
            continue
        canvas = draw(canvas, cell['bbox'], fmt='pil', color='green')
    return Image.fromarray(canvas)

In [None]:
# Show the annotated (ground-truth) boxes drawn on image 1.
display_ground_truth_bboxes(1)

## Tesseract Utils
From the image above we can see that the bounding boxes in our dataset fit the text poorly.<br>
The one redeeming quality is that the top-left corner seems pretty accurate.<br>
So, the plan is to use the boxes returned by Tesseract to clean it up.

Functions:
* `get_tessdata(i)`: filters output of tess.image_to_data to get a dataframe of words and bboxes for the i'th image
* `display_tess_bboxes(i)`: displays the result of get_tessdata

In [None]:
def tessdata_to_df(tessdata, keep_garbage=False):
    """Parse Tesseract's tab-separated output (a string) into a DataFrame.

    tessdata: text whose first line is the column header and whose remaining
        lines are tab-separated records, as produced by
        ``pytesseract.image_to_data``.
    keep_garbage: when False (the default), rows whose text is blank or whose
        confidence is not above 0 are dropped and the index is reset; the
        previous row numbers are kept in an 'index' column.

    The first ten columns are cast to int, the eleventh to float and the
    twelfth to str.
    """
    # Skip empty lines rather than always dropping the last one: that way a
    # trailing newline is ignored, and output without one keeps its last row.
    records = [line.split("\t") for line in tessdata.split("\n") if line]
    header = records[0]
    rows = records[1:]

    # Passing the header here keeps the columns even when there are no rows.
    df = pd.DataFrame(rows, columns=header)

    # set types
    dtypes = [int] * 10 + [float, str]
    for c, t in zip(df.columns, dtypes):
        df[c] = df[c].values.astype(t)

    if not keep_garbage:
        # Boolean Series mask (an empty plain list would select zero columns).
        has_text = df["text"].str.strip() != ""
        confident = df["conf"] > 0
        df = df[has_text & confident].reset_index()

    return df


def _trim_span(profile):
    """Return (start, stop) of the part of ``profile`` kept after trimming.

    The leading edge advances while the remaining mass is still above 98% of
    the total; the trailing edge then retreats while the kept mass is still
    above 96% of the total. At least two entries are always kept.
    """
    total = sum(profile)
    size = len(profile)

    start = 0
    while sum(profile[start:]) > total * 0.98 and start < size - 3:
        start += 1

    # NOTE(review): this threshold is relative to the full total, so when the
    # leading trim has already removed more than 4% the trailing edge stays put.
    stop = size
    while sum(profile[start:stop]) > total * 0.96 and stop > start + 2:
        stop -= 1

    return start, stop


def fit_bboxes_to_text(im, tessdata):
    """
    Tesseract bboxes sometimes have very wide margins on them. This is bad for the purpose
    of determining the grid that defines the layout.
    Shrinks each box along both axes based on how much "darkness" (1 - normalised
    brightness) the rows/columns contribute; see _trim_span for the thresholds.

    im <PIL.Image or np.array()>: image used for cropping out the boxes defined in tessdata.
        Either (H, W, channels), in which case channels are summed, or a 2-D array.

    tessdata <Pandas.DataFrame>: Should be a dataframe with columns ["left","top","width","height"]
        as is returned by tessdata_to_df

    Returns a copy of ``tessdata`` with the four box columns replaced; the
    input frame is left untouched.
    """
    # Work on private copies: writing through ``.values`` would either modify
    # the caller's frame in place or fail on a read-only array.
    new_left = tessdata["left"].values.copy()
    new_top = tessdata["top"].values.copy()
    new_width = tessdata["width"].values.copy()
    new_height = tessdata["height"].values.copy()

    darkness = np.array(im)
    if darkness.ndim == 3:
        darkness = darkness.sum(axis=2)
    darkness = darkness / darkness.max()
    darkness = abs(darkness - 1)

    boxes = tessdata[['left', 'top', 'width', 'height']].values
    for k, (x, y, w, h) in enumerate(boxes):
        cropped = darkness[y:y + h, x:x + w]

        top, bottom = _trim_span(cropped.sum(axis=1))  # per-row darkness
        left, right = _trim_span(cropped.sum(axis=0))  # per-column darkness

        new_left[k] = x + left
        new_top[k] = y + top
        new_width[k] = right - left
        new_height[k] = bottom - top

    df = tessdata.copy()
    df["left"] = new_left
    df["top"] = new_top
    df["width"] = new_width
    df["height"] = new_height
    return df


def im_to_data(im):
    """Run Tesseract on ``im`` (config "--psm 1") and return tessdata_to_df of its output."""
    return tessdata_to_df(tess.image_to_data(im, config="--psm 1"))

In [None]:
def get_tessdata(i):
    """Return Tesseract's word table for the i'th image.

    The image is enlarged by TESS_SCALE before OCR; the box columns are then
    floor-divided by the same factor so they refer to the original image.
    """
    box_cols = ['left', 'top', 'width', 'height']
    upscaled = scale(X(i), TESS_SCALE)
    # upscaled = np.array(upscaled).min(axis=2) > 100  # convert to bw (maybe unnecessary)
    words = im_to_data(upscaled)
    words[box_cols] = words[box_cols] // TESS_SCALE
    return words

In [None]:
def display_tess_bboxes(i):
    """Return the i'th image with every box from get_tessdata outlined in green."""
    canvas = X(i)
    boxes = get_tessdata(i)[['left', 'top', 'width', 'height']].values
    for box in boxes:
        # rows are (left, top, width, height), i.e. draw()'s 'cv2' format
        canvas = draw(canvas, box, fmt='cv2', color='green')
    return Image.fromarray(canvas)

In [None]:
# Show the Tesseract word boxes drawn on image 3.
display_tess_bboxes(3)

## Fuzzy Matching
Ok, so we have converted the raw input data into a more manageable dataframe via `get_data`.<br>
We also have the text and bounding boxes returned by tesseract.<br>
Tessdata, while cleaner, does miss some of the words entirely.

In [None]:
# Index of the sample inspected in this cell and the next one.
i = 0
# Tesseract's word table for sample i.
get_tessdata(i)

In [None]:
# Annotated cell table for the same sample i, for comparison with the OCR output.
get_data(i)

In [None]:
# Index of the sample inspected in the ground-truth cells further down.
f = 0

In [None]:
def ground_truth_grid(i, data=None):
    """List a table's cells as ``[row, cell_index, colspan, is_head]`` entries.

    i: index of the sample whose annotation is loaded when ``data`` is None.
    data: optional structure dict containing only a 'tokens' list; when it is
        given, ``i`` is not used.

    ``cell_index`` counts cells within their row starting at 0; it is the
    cell's ordinal, not its grid column (earlier colspans are not added in).
    Entries are returned in document order.

    Raises ValueError on an unrecognised token; rowspan attributes are not
    recognised and therefore raise as well.
    """
    if data is None:
        # Load the sample that was asked for (this used to read a notebook
        # global instead of the ``i`` argument).
        data = Y(i)['html']['structure']

    assert len(list(data.keys())) == 1, data.keys()
    is_head = False
    c = -1
    r = -1
    cspan = 1
    grid = []
    for t in data['tokens']:
        if t == '<thead>':
            is_head = True

        elif t == '</thead>':
            is_head = False

        elif t == '<tr>':
            c = -1
            r += 1

        elif '<td' in t:
            c += 1
            cspan = 1

        elif "colspan" in t:
            # attribute token looks like ' colspan="3"'; take the quoted int
            cspan = int(t.split('"')[1])

        elif t == '</td>':
            grid.append([r, c, cspan, is_head])

        elif t in ['>', '</tr>', '<tbody>', '</tbody>']:
            pass

        else:
            raise ValueError(t)

    return grid

def ground_truth_table(i):
    """Lay out the cell texts of the i'th sample as a list of rows.

    Each row is a list of strings of equal length. A cell's text is written
    to the first column it covers; the further columns covered by its colspan
    stay ''. Cell texts are the joined 'tokens' of each annotated cell, taken
    in the same order as the entries of ground_truth_grid.

    Returns [] when the table has no cells.
    """
    html = Y(i)['html']
    grid = ground_truth_grid(i, html['structure'])
    values = ["".join(cell['tokens']) for cell in html["cells"] if 'tokens' in cell]

    if not grid:
        return []

    n_rows = max(entry[0] for entry in grid) + 1

    # Table width is the widest row measured in grid columns (sum of the
    # colspans), not the largest cell ordinal, so spanned rows always fit.
    row_widths = [0] * n_rows
    for r, _, span, _ in grid:
        row_widths[r] += span
    n_cols = max(row_widths)

    rows = [[''] * n_cols for _ in range(n_rows)]

    next_col = [0] * n_rows  # first free grid column of each row
    for k, (r, _, span, _) in enumerate(grid):
        rows[r][next_col[r]] = values[k]
        next_col[r] += span
    return rows

## Mapping Tessdata to Ground Truth
I don't like the bboxes provided in the dataset.

In [None]:
# Uncomment to step to the next sample before re-running this cell.
# f += 1
im = X(f)  # image of sample f
data = Y(f)  # its annotation dict
# Render the reconstructed table, then the image itself as the last expression.
display(pd.DataFrame(ground_truth_table(f)))
im

In [None]:
# Per-cell [row, cell_index, colspan, is_head] entries for sample f.
ground_truth_grid(f)

In [None]:
# Joined token text of every annotated cell in `data` that has a 'tokens' entry.
["".join(x['tokens']) for x in data['html']["cells"] if 'tokens' in x]