## Imports

In [None]:
from pathlib import Path
import json
from collections import defaultdict

In [None]:
import cv2
from IPython.display import display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pytesseract as tess
from PIL import Image
from sklearn.cluster import DBSCAN

## Loading Data

In [None]:
# Root of the extracted PubTabNet archive (machine-specific absolute path).
data_dir = Path(r"C:\Users\stans\Documents\Projects\Datasets\pubtabnet.tar\pubtabnet\pubtabnet")
# NOTE(review): despite its name, test_dir points at the "train" folder — confirm intent.
test_dir = data_dir / "train"

In [None]:
# Path.glob yields entries in arbitrary, OS-dependent order; sorting keeps the
# integer indices used throughout this notebook pointing at the same images
# on every run and every machine.
files = sorted(test_dir.glob("*.png"))

In [None]:
# Keep every annotation record as its raw JSON line; records are parsed on demand.
with open(data_dir / "PubTabNet_2.0.0.jsonl", mode="r", encoding="utf8") as f:
    json_list = f.readlines()

In [None]:
# Map each annotated image filename to the position of its record in json_list.
# If a filename occurs more than once, the last occurrence wins.
fname_index_map = {
    json.loads(line)["filename"]: position
    for position, line in enumerate(json_list)
}

## Data Utils
Utility functions for interacting with the PubTabNet Dataset

In [None]:
def Y(i):
    """Return the parsed annotation record for the i-th entry of ``files``."""
    image_name = files[i].name
    record_position = fname_index_map[image_name]
    return json.loads(json_list[record_position])

def X(i):
    """Open the i-th entry of ``files`` as a PIL image."""
    image_path = files[i]
    return Image.open(image_path)

## Image Utils

In [None]:
def scale(im, scale):
    """
    Return ``im`` resized by the factor ``scale``.

    im <PIL.Image>: image to resize; only its ``size`` attribute and
        ``resize`` method are used.

    scale <int or float>: multiplier applied to every dimension of ``im.size``.
        Each scaled dimension is truncated to an int.
    """
    # resize() is documented to take a concrete size tuple, not a lazy generator.
    new_size = tuple(int(dim * scale) for dim in im.size)
    return im.resize(new_size)


# Named colour triples that can be referred to by name instead of by value.
COLORS = dict(
    black=(0, 0, 0),
    red=(150, 0, 0),
    green=(0, 150, 0),
    blue=(0, 0, 150),
)

def draw(im, bbox, color='black'):
    """
    Draw a 1-pixel rectangle outline on a copy of ``im`` and return the copy.

    im <PIL.Image or np.array>: image to draw on. It is copied with
        ``np.array`` first, so the object passed in is left untouched.

    bbox <sequence of 4 numbers>: rectangle as (x, y, width, height).
        Values are truncated to plain ints before being handed to OpenCV.

    color <str or tuple> ('black'): either a key of ``COLORS`` or a colour
        tuple that is passed to OpenCV unchanged.

    Returns the image array with the rectangle drawn on it.
    """
    canvas = np.array(im)  # copies, and converts a PIL image to an array
    x, y, w, h = (int(value) for value in bbox)

    if isinstance(color, str):
        color = COLORS[color]

    return cv2.rectangle(
        canvas,
        pt1=(x, y),
        pt2=(x + w, y + h),
        color=color,
        thickness=1,
    )

## Tesseract Utils

In [None]:
def tessdata_to_df(tessdata, keep_garbage=False):
    """
    Parse the tab-separated text produced by Tesseract into a DataFrame.

    tessdata <str>: TSV text whose first line is the header. A trailing
        newline is optional and blank lines are ignored.

    keep_garbage <bool> (False): when False, rows whose text is empty or
        whitespace-only and rows whose confidence is not positive are dropped,
        and the index is reset (the previous row labels are kept in an
        "index" column, as ``reset_index()`` does by default).

    Returns a DataFrame in which the first 11 columns are ints and the 12th
    is str. Fractional values (e.g. a confidence written as "96.5") are
    truncated to ints.

    Raises ValueError if ``tessdata`` contains no header line.
    """
    lines = [line.rstrip("\r") for line in tessdata.split("\n")]
    lines = [line for line in lines if line]
    if not lines:
        raise ValueError("tessdata is empty: expected a TSV header line")

    header = lines[0].split("\t")
    rows = [line.split("\t") for line in lines[1:]]
    df = pd.DataFrame(rows, columns=header)

    # set types; going through float first accepts values such as "96.5"
    dtypes = [int] * 11 + [str]
    for column, dtype in zip(df.columns, dtypes):
        if dtype is int:
            df[column] = df[column].astype(float).astype(int)
        else:
            df[column] = df[column].astype(dtype)

    if not keep_garbage:
        # A boolean Series (not a plain list) also filters an empty frame correctly.
        has_text = df["text"].str.strip() != ""
        df = df[has_text & (df["conf"] > 0)].reset_index()

    return df


def _trim_bounds(profile, lead_keep=0.98, total_keep=0.96):
    """Return (start, stop) indices of ``profile`` left after trimming both ends.

    The leading edge advances while the remaining mass is still above
    ``lead_keep`` of the total; the trailing edge then retreats while the
    remaining mass is still above ``total_keep`` of the total. At least two
    entries are kept whenever the profile is long enough.
    """
    total = profile.sum()
    size = len(profile)

    start = 0
    while start < size - 3 and profile[start:].sum() > total * lead_keep:
        start += 1

    stop = size
    while stop > start + 2 and profile[start:stop].sum() > total * total_keep:
        stop -= 1

    return start, stop


def fit_bboxes_to_text(im, tessdata):
    """
    Tesseract bboxes sometimes have very wide margins on them. This is bad for the purpose
    of determining the grid that defines the layout.
    Shrinks each bbox from its top/left edge and then from its bottom/right edge until
    a few percent of the darkness inside it is lost.

    im <PIL.Image or np.array()>: image used for cropping out the boxes defined in tessdata.
        May have a channel axis (channels are summed) or be 2-D grayscale.

    tessdata <Pandas.DataFrame>: Should be a dataframe with columns ["left","top","width","height"]
        as is returned by tessdata_to_df. It is not modified.

    Returns a copy of ``tessdata`` whose four bbox columns hold the tightened boxes.
    """
    pixels = np.asarray(im, dtype=float)
    if pixels.ndim == 3:
        pixels = pixels.sum(axis=2)

    # Normalise to [0, 1] and invert so that dark pixels carry the weight.
    peak = pixels.max()
    if peak > 0:  # an all-zero image would otherwise divide by zero
        pixels = pixels / peak
    darkness = np.abs(pixels - 1)

    box_columns = ["left", "top", "width", "height"]
    boxes = tessdata[box_columns].values
    # Work on a private copy: writing into the frame's own arrays would change
    # the caller's data (or fail when pandas hands out read-only arrays).
    fitted = boxes.copy()

    for i, (x, y, w, h) in enumerate(boxes):
        region = darkness[y:y + h, x:x + w]
        top, bottom = _trim_bounds(region.sum(axis=1))
        left, right = _trim_bounds(region.sum(axis=0))
        fitted[i] = (x + left, y + top, right - left, bottom - top)

    df = tessdata.copy()
    for position, column in enumerate(box_columns):
        df[column] = fitted[:, position]
    return df


def im_to_data(im):
    """Run Tesseract's ``image_to_data`` on ``im`` (config "--psm 1") and
    return the result parsed by ``tessdata_to_df`` with its default filtering."""
    return tessdata_to_df(tess.image_to_data(im, config="--psm 1"))

## Cell Detection by OCR Bbox
Tesseract's `image_to_data` function returns all detected words and their bounding boxes.<br>
The following process decides whether two tokens belong to the same cell based entirely on their proximity.<br>
Once the tokens are grouped, the bbox of each cell can be determined, and finally the table grid
can be fit to the cell boundaries.

### Combine Tokens into Cells

In [None]:
def group_tokens(df, im=None, direction="Both", v=0):
    """
    Merge neighbouring OCR tokens into cells using DBSCAN on points sampled
    from each token's bounding box.

    df <Pandas.DataFrame>: one row per token with columns
        ["left", "top", "width", "height", "text"].

    im <PIL.Image or np.array> (None): image the boxes refer to. Only used,
        and then required, for the debug overlay shown when v > 1.

    direction <str> ("Both"): only type-checked at the moment; it does not
        influence the grouping.

    v <int> (0): Verbose. Values above 1 display an overlay with the sampled
        points (red), links between clustered points (blue) and the resulting
        cell boxes (green).

    Returns a DataFrame with one row per cell and columns
    ["left", "top", "width", "height", "text"]. Cells are ordered by the
    position of their first token in ``df``; the text of a cell is its tokens
    joined by spaces in reading order (top to bottom, then left to right, at
    the resolution of the median token height).
    """
    if not isinstance(direction, str):
        raise TypeError("direction must be a string")
    if v > 1 and im is None:
        raise ValueError("an image is required when v > 1")

    cell_columns = ["left", "top", "width", "height", "text"]
    if len(df) == 0:
        # nothing to cluster; DBSCAN cannot be fitted on zero points
        return pd.DataFrame(columns=cell_columns)

    scale_y = 2  # token tops are doubled before clustering
    median_font_height = df["height"].median()
    eps_multiplier = 1

    model = DBSCAN(
        eps=median_font_height * eps_multiplier,  # TODO: gridsearch the multiple for this
        min_samples=2  # used to connect two tokens into the same cell
    )

    boxes = df[['left', 'top', 'width', 'height']].values

    points = []
    tokens = []  # tokens[k] is the positional index of the token owning points[k]
    for i, (x, y, w, h) in enumerate(boxes):
        y = y * scale_y
        tokens += [i] * 8  # the next 8 points correspond to token[i]

        points.append([x, y + h / 2])  # vertically centered, left
        points.append([x + w, y + h / 2])  # vertically centered, right
        points.append([x + w / 2, y])  # horizontally centered, top
        points.append([x + w / 2, y + h])  # horizontally centered, bottom
        points.append([x, y])  # top left
        points.append([x, y + h])  # bottom left
        points.append([x + w, y])  # top right
        points.append([x + w, y + h])  # bottom right

    if v > 1:
        # draw a red dot to show the points of each token
        im = np.array(im)  # ensure cv2 format
        for px, py in points:
            col, row = int(px), int(py / scale_y)
            # clamp the start so dots near the border do not wrap around
            im[max(row - 2, 0):row + 2, max(col - 2, 0):col + 2] = [255, 0, 0]

    groups = model.fit_predict(points)

    if v > 1:
        # draw a blue line between consecutive points of the same cluster
        for label in range(max(groups) + 1):
            members = [k for k, g in enumerate(groups) if g == label]
            for a, b in zip(members[:-1], members[1:]):
                p1 = int(points[a][0]), int(points[a][1] / scale_y)
                p2 = int(points[b][0]), int(points[b][1] / scale_y)
                im = cv2.line(
                    im,
                    p1,
                    p2,
                    color=(0, 0, 255),
                    thickness=2,
                )

    # shrink df to predicted cells
    #
    # Every DBSCAN cluster says "the tokens owning these points belong
    # together". Different clusters may mention the same tokens, so the
    # clusters are merged with a union-find structure: tokens that are already
    # in the same chain are never appended to it a second time.
    parent = list(range(len(boxes)))

    def find(token):
        while parent[token] != token:
            parent[token] = parent[parent[token]]  # path halving
            token = parent[token]
        return token

    cluster_anchor = {}  # cluster label -> first token seen in that cluster
    for token, label in zip(tokens, groups):
        if label < 0:
            continue  # noise point, links nothing
        if label not in cluster_anchor:
            cluster_anchor[label] = token
            continue
        root_a, root_b = find(cluster_anchor[label]), find(token)
        if root_a != root_b:
            # the smaller index stays root, so a chain's root is its first token
            parent[max(root_a, root_b)] = min(root_a, root_b)

    cells = defaultdict(list)
    for token in range(len(boxes)):
        cells[find(token)].append(token)

    texts = df["text"].values
    lefts, tops, widths, heights = boxes.T

    rows = []
    for members in cells.values():
        # sort tokens by y (at resolution of font height), break ties by x (at same res)
        ordered = sorted(
            members,
            key=lambda t: (
                tops[t] // median_font_height,
                lefts[t] // median_font_height,
            ),
        )
        idx = np.asarray(members)

        left = lefts[idx].min()
        right = (lefts[idx] + widths[idx]).max()
        top = tops[idx].min()
        bottom = (tops[idx] + heights[idx]).max()

        rows.append({
            "left": left,
            "top": top,
            "width": right - left,
            "height": bottom - top,
            "text": " ".join(texts[t] for t in ordered),
        })

    df = pd.DataFrame(rows)

    if v > 1:
        # draw a green box around detected token groups
        for bbox in df[['left', 'top', 'width', 'height']].values:
            im = draw(im, bbox, 'green')
        display(Image.fromarray(im))

    return df

In [None]:
# Upscale sample image 96 by 3x, OCR it, then tighten the token boxes.
im = scale(X(96), 3)
df = fit_bboxes_to_text(im, im_to_data(im))

In [None]:
# Merge the tokens into cells; v=2 also renders the debug overlay.
df = group_tokens(df, im, v=2)

In [None]:
# Show the resulting table of cells.
df

### Fit Grid to Cell Boundaries

In [None]:
def _runs_above_threshold(density, thresh):
    """Return [start, stop] index pairs of every maximal run with density > thresh.

    ``stop`` is the first index after the run; a run that is still open at the
    end of ``density`` is closed at ``len(density)``.
    """
    runs = []
    start = None
    for position, val in enumerate(density):
        if val > thresh:
            if start is None:
                start = position
        elif start is not None:
            runs.append([start, position])
            start = None
    if start is not None:
        runs.append([start, len(density)])
    return runs


def grid_detect(
    im,
    margin=0,
    v_thresh=0.1,
    h_thresh=0.2,
    v=0,
):
    """
    Detect the cell grid of a table image from its OCR'd text boxes.

    im <PIL.Image or imarray>: Image of a table

    margin <int> (0): Number of pixels to shrink each cell bbox for the purpose
        of determining the grid. Negative values grow the boxes instead.

    v_thresh <float> (0.1): Fraction of the peak row density above which a row of
        pixels counts as covered by text. Higher values mean fewer/narrower rows.

    h_thresh <float> (0.2): Fraction of the peak column density above which a column
        of pixels counts as covered by text. Higher values mean fewer/narrower columns.

    v <int> (0): Verbose, 0-3. Show various stages of progress.

    Returns the grid as a list of [x, y, width, height] boxes, one per
    combination of detected row band and column band (row-major).
    """

    # use Tesseract to find all words
    df = im_to_data(im)
    df = fit_bboxes_to_text(im, df) # try to remove any extraneous whitespace from the bbox
    df = group_tokens(df, im, v=0) # spatially group bboxes into cells using DBSCAN

    # create a word mask from the cell bboxes, indexed as mask[x, y]
    height, width = np.asarray(im).shape[:2]
    mask = np.zeros((width, height))
    boxes = df[['left', 'top', 'width', 'height']].values if len(df) else []
    for x, y, w, h in boxes:
        # clamp at 0: a negative slice bound would wrap around to the far edge
        x0, x1 = max(x + margin, 0), max(x + w - margin, 0)
        y0, y1 = max(y + margin, 0), max(y + h - margin, 0)
        mask[x0:x1, y0:y1] = 1

    # per-row and per-column coverage, scaled to a peak of 1 (left at 0 if empty)
    v_density = mask.sum(axis=0)
    h_density = mask.sum(axis=1)
    if v_density.max() > 0:
        v_density = v_density / v_density.max()
    if h_density.max() > 0:
        h_density = h_density / h_density.max()

    grid_y = _runs_above_threshold(v_density, v_thresh)
    grid_x = _runs_above_threshold(h_density, h_thresh)

    grid = [
        [x1, y1, x2 - x1, y2 - y1]
        for y1, y2 in grid_y
        for x1, x2 in grid_x
    ]

    if v > 0:
        for bbox in grid:
            im = draw(im, bbox, 'red')
    if v > 1:
        fig, axs = plt.subplots(2,2)
        axs[1,0].imshow(mask.T)
        axs[1,0].xaxis.set_visible(False)
        axs[1,0].yaxis.set_visible(False)
        axs[1,1].plot(v_density, np.arange(mask.shape[1],0,-1))
        axs[1,1].xaxis.set_visible(False)
        axs[1,1].yaxis.set_visible(False)
        axs[0,0].plot(h_density)
        axs[0,0].xaxis.set_visible(False)
        axs[0,0].yaxis.set_visible(False)
        axs[0,1].imshow(im)
        axs[0,1].xaxis.set_visible(False)
        axs[0,1].yaxis.set_visible(False)
        plt.show()

    if v > 2:
        plt.imshow(mask.T)
        plt.show()
    if v > 0:
        # np.asarray covers the case where nothing was drawn and im is still a PIL image
        display(Image.fromarray(np.asarray(im)))

    return grid


In [None]:
# Index of the image used by the demo cells that follow.
f_num = 100

In [None]:
# Show the current image index.
f_num

In [None]:
# Step to the next image and run grid detection on it at 3x scale with v=3.
# Re-running this cell walks through the images one at a time.
f_num += 1
grid_detect(scale(X(f_num), 3), v=3, margin=0)

In [None]:
# Image 96 at 3x scale, default thresholds.
grid_detect(scale(X(96), 3), v=3, margin=0)

In [None]:
# Image 82 at 2x scale, default thresholds.
grid_detect(scale(X(82), 2), v=3, margin=0)

In [None]:
# Image 80 at 2x scale, default thresholds.
grid_detect(scale(X(80), 2), v=3, margin=0)

In [None]:
# Image 70 at 2x scale; margin=-3 with both thresholds at 0, so any density
# above zero counts as covered.
grid_detect(scale(X(70), 2), v=3, margin=-3, v_thresh=0, h_thresh=0)

In [None]:
# Image 48 at 2x scale; margin=-5 with both thresholds at 0, so any density
# above zero counts as covered.
grid_detect(scale(X(48), 2), v=3, margin=-5, v_thresh=0, h_thresh=0)

In [None]:
# Image 64 at 3x scale, default thresholds.
grid_detect(scale(X(64), 3), v=3, margin=0)

## Cell Detection by Color
More or less the same idea as above, except that the detected cell areas are much noisier.<br>
In this process we rely entirely on the pixel color values to determine where the cells are (no OCR needed).

In [None]:
# Ground truth: draw the annotated cell boxes of image f_num on the unscaled
# image itself, because the annotation coordinates refer to the original pixels.
imarr1 = np.array(X(f_num))
for cell in Y(f_num)['html']['cells']:
    if 'bbox' in cell:
        # PubTabNet stores cell boxes as [x0, y0, x1, y1], while draw() expects
        # (x, y, width, height); passing the corners straight through is what
        # produced the weird box sizes.
        x0, y0, x1, y1 = cell['bbox']
        imarr1 = draw(imarr1, (x0, y0, x1 - x0, y1 - y0), 'green')

# Detected cells: `df` was computed on the scaled image `im`, so draw on that.
# NOTE(review): `im`/`df` come from the image processed earlier in the notebook,
# which is not necessarily image f_num — confirm which pair should be compared.
imarr2 = np.array(im)
for bbox in df[['left', 'top', 'width', 'height']].values:
    imarr2 = draw(imarr2, bbox, 'red')

In [None]:
# Show the image with the annotated (green) boxes, enlarged 2x.
scale(Image.fromarray(imarr1),2)

In [None]:
# Show the image with the detected (red) boxes.
Image.fromarray(imarr2)