In [12]:
import io
import numpy as np
import pandas as pd
import cv2
import fitz
import pyocr
from PIL import Image
from matplotlib import pyplot as plt

In [13]:
# Open the sample receipt PDF (path relative to the notebook) and work on its first page.
doc = fitz.open("resource/ANA領収書.pdf")
page = doc[0]

def get_drawing_pix(page, matrix=None):
    '''Render only the vector drawings of a page (no text, no images) to a pixmap.

    Every path returned by ``page.get_drawings()`` is replayed onto a blank page of
    the same size, which is then rendered.

    Parameters
    ----------
    page: fitz.Page
        page whose drawings are rendered
    matrix: fitz.Matrix, optional
        transformation (e.g. a zoom) passed to ``get_pixmap``; when omitted the page is
        rendered with ``get_pixmap()``'s own default.

    Returns
    -------
    fitz.Pixmap
        rendering of the drawings only

    Raises
    ------
    ValueError
        if a path contains a draw command other than line, rectangle, quad or curve.
    '''
    def _value(path, key, default):
        # properties that do not apply to a path (e.g. line width of a fill-only path)
        # may be missing or None; Shape.finish() needs concrete values
        value = path.get(key)
        return default if value is None else value

    paths = page.get_drawings()  # extract existing drawings
    outpdf = fitz.open()
    try:
        outpage = outpdf.new_page(width=page.rect.width, height=page.rect.height)
        shape = outpage.new_shape()  # drawing canvas for the output page

        for path in paths:
            for item in path["items"]:  # these are the draw commands
                if item[0] == "l":  # line
                    shape.draw_line(item[1], item[2])
                elif item[0] == "re":  # rectangle
                    shape.draw_rect(item[1])
                elif item[0] == "qu":  # quad
                    shape.draw_quad(item[1])
                elif item[0] == "c":  # curve
                    shape.draw_bezier(item[1], item[2], item[3], item[4])
                else:
                    raise ValueError("unhandled drawing", item)

            line_cap = path.get("lineCap")  # sequence of numbers; its maximum is used
            shape.finish(
                fill=path.get("fill"),  # fill color
                color=path.get("color"),  # line color
                dashes=path.get("dashes"),  # line dashing
                even_odd=path.get("even_odd", True),  # control color of overlaps
                closePath=_value(path, "closePath", False),  # connect last and first point
                lineJoin=_value(path, "lineJoin", 0),  # how line joins should look like
                lineCap=max(line_cap) if line_cap else 0,  # how line ends should look like
                width=_value(path, "width", 1),  # line width
                stroke_opacity=_value(path, "stroke_opacity", 1),
                fill_opacity=_value(path, "fill_opacity", 1),
            )
        shape.commit()
        if matrix is None:
            return outpage.get_pixmap()
        return outpage.get_pixmap(matrix=matrix)
    finally:
        # the returned pixmap holds its own samples; the scratch document can be released
        outpdf.close()

def binarize_image(img, threshold=210):
    '''Binarize an image so that dark pixels become the foreground.

    Pixels whose gray value is <= ``threshold`` become 255 and all others 0
    (``cv2.THRESH_BINARY_INV``), so dark ink and lines end up white on black.

    Parameters
    ----------
    img: numpy.ndarray
        RGB image of shape (H, W, 3), or an already single-channel (H, W) gray image
    threshold: int
        gray level separating foreground from background

    Returns
    -------
    numpy.ndarray
        (H, W) binary image with values 0 and 255
    '''
    # a 2-D image is already gray; cvtColor(..., COLOR_RGB2GRAY) would reject it
    img_gray = img if img.ndim == 2 else cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    _, img_bin = cv2.threshold(img_gray, threshold, 255, cv2.THRESH_BINARY_INV)
    return img_bin

# Render only the page's vector drawings, decode the encoded pixmap bytes (PNG by default)
# into an array and binarize it so that the ruled lines become 255 on a 0 background.
pixmap = get_drawing_pix(page)
img1 = np.asarray(Image.open(io.BytesIO(pixmap.tobytes())))
img1_bin = binarize_image(img1)

plt.imshow(img1_bin, cmap="gray")
plt.show()

<IPython.core.display.Javascript object>

In [14]:
import re
from functools import lru_cache

class rdict(dict):
    '''Regex dictionary.

    A dict whose lookups also accept regex patterns. A string key matches a pattern when
    ``pattern.search(key)`` succeeds; a tuple key matches when any of its string elements
    does. Other keys (ints, ...) never match a pattern but stay reachable by ordinary
    exact lookup.
    '''
    def __getitem__(self, key):
        '''Exact lookup, or the list of values whose keys match a compiled ``re.Pattern``.

        Raises KeyError when nothing matches.
        '''
        if not isinstance(key, re.Pattern):
            return super().__getitem__(key)
        values = [v for k, v in self.items() if rdict._key_search(key, k)]
        if not values:
            raise KeyError(key)
        return values

    def rget(self, *keys):
        '''Return the values whose keys match every pattern in ``keys``.

        ``keys`` are regex strings (or compiled patterns). Raises KeyError when no key
        matches all of them.
        '''
        patterns = [rdict._key_compile(key) for key in keys]
        values = [v for k, v in self.items()
                  if all(rdict._key_search(p, k) for p in patterns)]
        if not values:
            raise KeyError(keys)
        return values

    def __contains__(self, key):
        '''Regex membership for a compiled pattern, exact membership otherwise.'''
        if isinstance(key, re.Pattern):
            return any(rdict._key_search(key, k) for k in self.keys())
        return super().__contains__(key)

    def is_rin(self, *keys):
        '''True if some key matches every pattern in ``keys`` (see ``rget``).'''
        patterns = [rdict._key_compile(key) for key in keys]
        return any(all(rdict._key_search(p, k) for p in patterns)
                   for k in self.keys())

    @staticmethod
    @lru_cache(maxsize=128)
    def _key_compile(key):
        return re.compile(key)

    @staticmethod
    def _key_search(pattern, key):
        '''Match ``pattern`` against a dict key; return a truthy match or None.'''
        if isinstance(key, str):
            return pattern.search(key)
        if isinstance(key, tuple):
            # non-string elements of a tuple key cannot be searched; skip them
            matches = [pattern.search(k) for k in key if isinstance(k, str)]
            return matches if any(matches) else None
        return None

In [15]:
# Quick demo of rdict: exact keys, regex-pattern lookups and tuple keys
# (a tuple key matches when any of its elements matches).
d = rdict()
d["foo"] = "nar"
d["baz"] = 100
d[("foo", "bar", "quot")] = 12345
d[340] = 1010

print(d)
p = re.compile("fo")
print(d[p])  # values whose key matches /fo/
print(d[340])  # ordinary exact lookup
print(d.rget("ba*"))  # rget compiles strings into patterns
print("fo" in d)  # a plain string tests exact membership only
print(p in d)  # a compiled pattern tests regex membership
print(d.is_rin("quot"))
print(d.rget("foo", "quo"))  # every pattern must match the same key

{'foo': 'nar', 'baz': 100, ('foo', 'bar', 'quot'): 12345, 340: 1010}
['nar', 12345]
1010
[100, 12345]
False
True
True
[12345]


In [16]:
from copy import deepcopy

class Cell():
    '''A table cell: bounding box, text content and the row/column groups it spans.

    ``row_group`` / ``col_group`` are sets of row / column indices; a merged cell spans
    several of them. Item access (``cell["bbox"]``) is an alias for attribute access.
    '''
    def __init__(self, bbox, content, row_group=None, col_group=None):
        self.bbox = bbox  # (x1, y1, x2, y2)
        self.content = content
        # copy given groups so later updates never leak into a set owned by the caller
        self.row_group = set() if row_group is None else set(row_group)
        self.col_group = set() if col_group is None else set(col_group)

    def __getitem__(self, key):
        return getattr(self, key)

    def __setitem__(self, key, value):
        setattr(self, key, value)

    def __repr__(self):
        return f"{self.bbox}: {self.content}: {self.row_group}"

    def width(self):
        return self.bbox[2] - self.bbox[0]

    def height(self):
        return self.bbox[3] - self.bbox[1]


class Table():
    '''Grid of cells reconstructed from cell bounding boxes.

    Cells are assigned to rows and columns with a pairwise relation function (such as
    ``infer_cell_relation``). Afterwards ``cell_matrix[row, col]`` (also ``table[row, col]``)
    holds the cell covering that grid position: a merged cell occupies several positions,
    and positions covered by no cell hold None.

    Attributes
    ----------
    cells: list[Cell] or None
        grouped copies of the loaded cells
    cell_matrix: numpy.ndarray (dtype=object) or None
        (num_row, num_col) grid of cells
    shape: tuple(int, int) or None
        (num_row, num_col)
    '''
    def __init__(self, cells=None, relation_infer_func=None, **kwargs):
        self.cells = None
        self.cell_matrix = None
        self.shape = None
        if cells is not None:
            # kwargs (e.g. ``threshold``) are forwarded to the relation function
            self.load_cells(cells, relation_infer_func, **kwargs)

    def __getitem__(self, idx):
        return self.cell_matrix[idx]

    @staticmethod
    def _assign_groups(cells, group_key, size_of, related):
        '''Fill ``cell[group_key]`` for every cell; return the cells in visiting order.

        Cells are visited from the smallest to the largest ``size_of(cell)``. A cell
        inherits the groups of every previously visited cell for which
        ``related(cell, previous)`` is true; a cell that inherits nothing opens a new
        group numbered 0, 1, 2, ...
        '''
        ordered = sorted(cells, key=size_of)
        visited = []
        next_group = 0
        for cur_cell in ordered:
            for p_cell in visited:
                if related(cur_cell, p_cell):
                    cur_cell[group_key].update(p_cell[group_key])
            if len(cur_cell[group_key]) == 0:
                cur_cell[group_key].add(next_group)
                next_group += 1
            visited.append(cur_cell)
        return ordered

    @staticmethod
    def _group_cells(cells, relation_infer_func=None, **kwargs):
        '''Label each cell with the rows and columns it belongs to.

        Parameters
        ----------
        cells: list[Cell]
            cells with their bbox set; they are deep-copied, never modified
        relation_infer_func: callable, optional
            ``relation_infer_func(cell1, cell2, **kwargs)`` returning a CellRelation.
            Defaults to a function relating nothing (each cell gets its own row and column).
        **kwargs
            forwarded to ``relation_infer_func``

        Returns
        -------
        list[Cell]
            copies of ``cells`` (ordered by width) with row_group and col_group set.
            Group numbers are arbitrary here; ``_sort_cellgroup_index`` renumbers them.
        '''
        if relation_infer_func is None:
            relation_infer_func = lambda c1, c2, **_: CellRelation.NONE

        def same_row(c1, c2):
            return relation_infer_func(c1, c2, **kwargs) in (
                CellRelation.SAME_ROW, CellRelation.PARTIALLY_SAME_ROW)

        def same_col(c1, c2):
            return relation_infer_func(c1, c2, **kwargs) in (
                CellRelation.SAME_COL, CellRelation.PARTIALLY_SAME_COL)

        _cells = deepcopy(cells)
        for cell in _cells:
            # group from scratch even if the input cells carry groups from an earlier load
            cell["row_group"] = set()
            cell["col_group"] = set()
        _cells = Table._assign_groups(_cells, "row_group", lambda cell: cell.height(), same_row)
        _cells = Table._assign_groups(_cells, "col_group", lambda cell: cell.width(), same_col)
        return _cells

    @staticmethod
    def _ordinal_index_map(cells, group_key, coord):
        '''Map each group number to its rank when groups are ordered by ``bbox[coord]``.

        Each group is positioned by the first cell that belongs to that group alone.
        '''
        representatives = {}
        for cell in cells:
            groups = cell[group_key]
            if len(groups) == 1:
                (group,) = groups
                representatives.setdefault(group, cell["bbox"][coord])
        ordered = sorted(representatives, key=representatives.__getitem__)
        return {group: idx for idx, group in enumerate(ordered)}

    @staticmethod
    def _sort_cellgroup_index(cells):
        '''Renumber row groups top-to-bottom and column groups left-to-right, in place.

        Parameters
        ----------
        cells: list[Cell]
            cells whose row_group and col_group were set by ``_group_cells``

        Return
        ------
        None
        '''
        row_index_map = Table._ordinal_index_map(cells, "row_group", 1)
        col_index_map = Table._ordinal_index_map(cells, "col_group", 0)
        for cell in cells:
            cell["row_group"] = {row_index_map[orig_idx] for orig_idx in cell["row_group"]}
            cell["col_group"] = {col_index_map[orig_idx] for orig_idx in cell["col_group"]}

    def load_cells(self, cells, relation_infer_func=None, **kwargs):
        '''Load cells into the table and build the cell matrix.

        Parameters
        ----------
        cells: list[Cell]
            cells with their bbox (x1, y1, x2, y2) and content set; copies are stored
        relation_infer_func: callable, optional
            see ``_group_cells``
        **kwargs
            forwarded to ``relation_infer_func``

        Returns
        -------
        tuple(int, int)
            shape (num_row, num_col) of the table; (0, 0) for no cells
        '''
        self.cells = Table._group_cells(cells, relation_infer_func, **kwargs)
        Table._sort_cellgroup_index(self.cells)

        if self.cells:
            num_row = max(max(cell["row_group"]) for cell in self.cells) + 1
            num_col = max(max(cell["col_group"]) for cell in self.cells) + 1
        else:
            num_row = num_col = 0
        self.cell_matrix = np.full((num_row, num_col), None, dtype=object)

        # later (wider) cells overwrite earlier ones sharing a position
        for cell in self.cells:
            for row in cell["row_group"]:
                for col in cell["col_group"]:
                    self.cell_matrix[row, col] = cell
        self.shape = (num_row, num_col)
        return (num_row, num_col)

    def is_row_merged_cell(self, i, j):
        '''True if the cell at (i, j) spans more than one row; False for an empty position.'''
        cell = self[i, j]
        return cell is not None and len(cell["row_group"]) > 1

    def is_col_merged_cell(self, i, j):
        '''True if the cell at (i, j) spans more than one column; False for an empty position.'''
        cell = self[i, j]
        return cell is not None and len(cell["col_group"]) > 1

In [17]:
from enum import Enum

class CellRelation(Enum):
    '''Spatial relation between two cell bounding boxes, as returned by infer_cell_relation.

    Per axis, infer_cell_relation calls an overlap "same" when it exceeds a threshold
    fraction of both boxes' extents, and "partial" when it exceeds that fraction of the
    smaller extent only.
    '''
    NONE = 0  # no qualifying overlap on either axis
    SAME_CELL = 1  # "same" overlap on both axes
    SAME_ROW = 2  # "same" vertical overlap, no horizontal overlap
    SAME_COL = 3  # "same" horizontal overlap, no vertical overlap
    PARTIALLY_SAME_ROW = 4  # "partial" vertical overlap, no horizontal overlap
    PARTIALLY_SAME_COL = 5  # "partial" horizontal overlap, no vertical overlap
    INSIDE_CELL = 6  # overlap on both axes, but not "same" on both
    # Commented-out direction-aware variant (not used):
    #R_PARTIALLY_SAME_ROW = 4    # overlaps partially and horizontally with the right cell bigger
    #L_PARTIALLY_SAME_ROW = 5    # overlaps partially and horizontally with the left cell bigger
    #T_PARTIALLY_SAME_COL = 6    # overlaps partially and vertically with the top cell bigger
    #B_PARTIALLY_SAME_COL = 7    # overlaps partially and vertically with the bottom cell bigger
    #INSIDE_CELL = 8
    

def infer_cell_relation(cell1, cell2, threshold=0.9):
    '''Classify how two cell bounding boxes relate to each other.

    Along each axis the overlap of the two boxes is compared with their extents. It is
    "same" when it exceeds ``threshold`` times both extents, "partial" when it exceeds
    ``threshold`` times only the smaller one, and "none" otherwise. Overlap along y decides
    the row relation; overlap along x decides the column relation.

    Parameters
    ----------
    cell1, cell2:
        objects whose ``["bbox"]`` is (x1, y1, x2, y2)
    threshold: float
        required overlap fraction

    Returns
    -------
    CellRelation
        SAME_CELL when both axes are "same"; INSIDE_CELL when both axes overlap otherwise;
        else the row relation (SAME_ROW / PARTIALLY_SAME_ROW), then the column relation
        (SAME_COL / PARTIALLY_SAME_COL); NONE when nothing overlaps.
    '''
    def overlap_kind(start_a, end_a, start_b, end_b):
        len_a = end_a - start_a
        len_b = end_b - start_b
        overlap = min(len_a, len_b, end_a - start_b, end_b - start_a)
        if not overlap > min(len_a, len_b) * threshold:
            return "none"
        if overlap > len_a * threshold and overlap > len_b * threshold:
            return "same"
        return "partial"

    ax1, ay1, ax2, ay2 = cell1["bbox"]
    bx1, by1, bx2, by2 = cell2["bbox"]
    col_kind = overlap_kind(ax1, ax2, bx1, bx2)
    row_kind = overlap_kind(ay1, ay2, by1, by2)

    if row_kind == "same" == col_kind:
        return CellRelation.SAME_CELL
    if "none" not in (row_kind, col_kind):
        return CellRelation.INSIDE_CELL

    row_relations = {"same": CellRelation.SAME_ROW,
                     "partial": CellRelation.PARTIALLY_SAME_ROW}
    if row_kind in row_relations:
        return row_relations[row_kind]
    col_relations = {"same": CellRelation.SAME_COL,
                     "partial": CellRelation.PARTIALLY_SAME_COL}
    return col_relations.get(col_kind, CellRelation.NONE)


def get_bbox(contour):
    '''Axis-aligned bounding box (min_x, min_y, max_x, max_y) of a contour.

    ``contour`` is indexed as ``contour[k, 0, 0]`` (x) and ``contour[k, 0, 1]`` (y), i.e.
    the point-array layout produced by ``cv2.findContours``.
    '''
    xs = contour[:, 0, 0]
    ys = contour[:, 0, 1]
    return (xs.min(), ys.min(), xs.max(), ys.max())

In [18]:
from collections import defaultdict

# Detect cells in the drawing image. With RETR_TREE, hierarchy1[0][i] is
# [next, previous, first_child, parent]: a contour with no child but with a parent is an
# innermost region, taken here as a table cell. Cells are grouped by their parent
# contour's index (presumably one group per table).
contours1, hierarchy1 = cv2.findContours(img1_bin, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
tables_raw = defaultdict(list)
img1_cp = np.zeros_like(img1_bin)  # canvas showing only the detected cells
for i, cont in enumerate(contours1):
    if hierarchy1[0][i][2] == -1 and hierarchy1[0][i][3] != -1:    # the most inside cell
        bbox = get_bbox(cont)
        tables_raw[hierarchy1[0][i][3]].append(Cell(bbox, ""))
        cv2.drawContours(img1_cp, contours1, i, 255, 1)
plt.imshow(img1_cp, cmap="gray")
plt.show()

In [19]:
# Attach the page text inside each cell's bbox, then build one Table per contour group.
# NOTE(review): the bboxes come from the drawing pixmap rendered without a zoom matrix,
# so they are assumed to be in page coordinates as get_text(clip=...) expects — confirm.
tables = []
for cells in tables_raw.values():
    for cell in cells:
        cell.content = page.get_text(clip=cell.bbox)
    table = Table(cells, infer_cell_relation)
    tables.append(table)

# `table` is the loop variable here, i.e. the last Table built above.
print(table[1,0])

(702, 337, 800, 386): 照会番号
: {1}


In [20]:
def plot_square(ax=None, bbox=(0,0,1,1), label=None, fontsize=6, fontname="MS Gothic"):
    '''Draw the outline of a bounding box, optionally with a centred label.

    Parameters
    ----------
    ax: matplotlib Axes or the pyplot module, optional
        where to draw; defaults to pyplot (the current axes)
    bbox: sequence
        (x1, y1, x2, y2)
    label: str, optional
        text written at the centre of the box
    fontsize: int
        label font size
    fontname: str
        label font; the default "MS Gothic" may not be available on every platform
    '''
    if ax is None:
        ax = plt
    x1, y1, x2, y2 = bbox[0], bbox[1], bbox[2], bbox[3]
    # closed outline: top-left -> top-right -> bottom-right -> bottom-left -> top-left
    xs = np.array([x1, x2, x2, x1, x1])
    ys = np.array([y1, y1, y2, y2, y1])
    ax.plot(xs, ys)
    if label is not None:
        ax.text((x1 + x2) / 2, (y1 + y2) / 2, label, ha="center", va="center",
                fontsize=fontsize, fontname=fontname)

%matplotlib notebook
fig, ax = plt.subplots(figsize=(5,4))
for table in tables:
    nrow, ncol = table.shape
    for i in range(nrow):
        for j in range(ncol):
            cell = table[i,j]
            plot_square(ax, cell["bbox"], f"{cell['row_group']},{cell['col_group']}")
ax.set_xlim(0,1364)
ax.set_ylim(1065,0)
plt.show()

<IPython.core.display.Javascript object>

In [62]:
def logarithm_threshold(bbox_whs):
    '''Split box sizes into "small" and "large" at the widest gap on a log2 scale.

    For widths and heights separately, the sizes are sorted in log2 space and the
    threshold is placed halfway across the largest gap between consecutive values
    (i.e. at the geometric mean of the two sizes bordering that gap).

    Parameters
    ----------
    bbox_whs: array-like of shape (N, 2)
        (width, height) of each bounding box; rows with a non-positive width or height
        are ignored.

    Returns
    -------
    (w_threshold, h_threshold): tuple of float

    Raises
    ------
    ValueError
        if fewer than two boxes have a positive width and height.
    '''
    def widest_gap_midpoint(log_sizes):
        ordered = np.sort(log_sizes)
        gaps = np.diff(ordered)
        k = gaps.argmax()
        return np.exp2(ordered[k] + gaps[k] / 2)

    whs = np.asarray(bbox_whs, dtype=float)
    if whs.size == 0:
        whs = whs.reshape(0, 2)
    positive = whs[np.all(whs > 0, axis=1)]
    if len(positive) < 2:
        raise ValueError("at least two boxes with positive width and height are required")
    log_whs = np.log2(positive)
    return (widest_gap_midpoint(log_whs[:, 0]), widest_gap_midpoint(log_whs[:, 1]))

def remove_characters(img_bin, w_threshold=200):
    '''Erase narrow connected components (assumed to be characters) from a binary image.

    Outer contours are found with ``cv2.RETR_CCOMP``; every component whose bounding box
    is narrower than ``w_threshold`` is filled with 0, which leaves the wide figures
    (e.g. ruled lines) in place.

    Parameters
    ----------
    img_bin: numpy.ndarray
        binary image with the foreground at 255 (see ``binarize_image``)
    w_threshold: float or None
        width limit in pixels; ``None`` derives it from the data with
        ``logarithm_threshold`` (which needs at least two components).

    Returns
    -------
    numpy.ndarray
        a copy of ``img_bin`` with the narrow components erased
    '''
    img_cp = img_bin.copy()
    contours, hierarchy = cv2.findContours(img_bin, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_SIMPLE)
    if hierarchy is None:
        # no foreground at all: nothing to erase
        return img_cp

    # keep outer boundaries only (no parent); holes are their children under RETR_CCOMP
    contours_ext = tuple(cont for cont, info in zip(contours, hierarchy[0]) if info[3] == -1)
    bbox_whs = np.array([(bbox[2] - bbox[0], bbox[3] - bbox[1])
                         for bbox in map(get_bbox, contours_ext)])
    if w_threshold is None:
        w_threshold = logarithm_threshold(bbox_whs)[0]

    for i, (width, _height) in enumerate(bbox_whs):
        if width < w_threshold:
            # fill the whole component (thickness -1) with background
            cv2.drawContours(img_cp, contours_ext, i, 0, -1)
    return img_cp

def emboss_characters(img_bin, dilate_size=6):
    '''Keep only the character pixels of a binary image.

    The non-character figures found by ``remove_characters`` are dilated with a
    ``dilate_size`` x ``dilate_size`` square kernel (presumably so that pixels along
    their edges are covered as well) and blanked out; everything else is kept.

    Parameters
    ----------
    img_bin: numpy.ndarray
        binary image with the foreground at 255
    dilate_size: int
        side of the square dilation kernel

    Returns
    -------
    numpy.ndarray
        new image containing only the character pixels of ``img_bin``
    '''
    img_nochars = remove_characters(img_bin)
    kernel = np.ones((dilate_size, dilate_size), np.uint8)
    figure_mask = cv2.dilate(img_nochars, kernel=kernel) == 255
    return np.where(figure_mask, 0, img_bin)


# Render the whole page (text included) at 5x zoom, binarize it, and split it into a
# "figures only" image and a "characters only" image.
zoom = fitz.Matrix(5, 5)
img2_pil = Image.open(io.BytesIO(page.get_pixmap(matrix=zoom).tobytes()))
img2 = np.asarray(img2_pil)
img2_bin = binarize_image(img2)
img_nochars = remove_characters(img2_bin)
img_chars = emboss_characters(img2_bin)

# a fresh figure, so the panels are not drawn into a figure left over from an earlier cell
fig, (ax_nochars, ax_chars) = plt.subplots(1, 2)
ax_nochars.set_title("Remove characters")
ax_nochars.imshow(img_nochars, cmap="gray")
ax_chars.set_title("Emboss characters")
ax_chars.imshow(img_chars, cmap="gray")
plt.show()

<IPython.core.display.Javascript object>

In [63]:
import defusedxml.ElementTree as ET

In [64]:
def parse_position(position, frame, root, max_ref_depth=10):
    '''Resolve a position string into a tuple of ints (x1, y1, x2, y2).

    Each comma-separated component is one of:

    * an integer literal, e.g. ``"120"`` (surrounding whitespace is ignored);
    * ``"*"``: the component with the same index in ``frame``;
    * a reference ``"#<id>.<coord>"`` with ``coord`` in x1/y1/x2/y2, e.g. ``"#12.y1"``:
      that component of the position of the first element of ``root`` whose ``id``
      attribute equals ``<id>`` (resolved recursively).

    Parameters
    ----------
    position: str
        position specification
    frame: sequence of numbers
        box used for ``"*"`` components (e.g. ``tuple(page.rect)``)
    root: iterable of elements
        elements providing ``.get("id")`` and ``.get("position")`` (e.g. a parsed XML root)
    max_ref_depth: int
        maximum reference nesting; guards against reference cycles

    Returns
    -------
    tuple of int

    Raises
    ------
    ValueError
        on a missing or malformed component, an unknown reference, or nesting deeper
        than ``max_ref_depth``.
    '''
    if max_ref_depth < 0:
        raise ValueError("Limit of depth exceeded while parsing position!")
    if position is None:
        raise ValueError("Missing position!")

    coord_names = ["x1", "y1", "x2", "y2"]

    def resolve(token, i):
        token = token.strip()
        if token == "*":
            return int(frame[i])
        if token.startswith("#"):
            ref_id, sep, coord = token[1:].partition(".")
            if not sep or coord not in coord_names:
                raise ValueError(f"Invalid position reference: {token!r}")
            for element in root:
                if element.get("id") == ref_id:
                    ref_position = parse_position(element.get("position"), frame, root,
                                                  max_ref_depth - 1)
                    return ref_position[coord_names.index(coord)]
            raise ValueError(f"Unknown position reference: {token!r}")
        try:
            return int(token)
        except ValueError:
            raise ValueError(f"Invalid position component: {token!r}") from None

    return tuple(resolve(token, i) for i, token in enumerate(position.split(",")))
            
            

def match_score(pattern_file, page, img_page):
    '''Fraction of ``check_format`` textboxes of a pattern file whose text matches the page.

    Every ``<textbox action="check_format" position="...">`` element of the pattern file
    is checked: the page text inside its position (see ``parse_position``) is compared
    with the element's text, ignoring surrounding whitespace.

    Parameters
    ----------
    pattern_file: str or file object
        XML pattern file
    page: fitz.Page
        page to score; ``tuple(page.rect)`` is the frame for ``"*"`` components
    img_page:
        currently unused — NOTE(review): presumably reserved for image-based checks

    Returns
    -------
    float
        matched / checked, or 0.0 when the pattern file has no ``check_format`` textbox
    '''
    root = ET.parse(pattern_file).getroot()
    frame = tuple(page.rect)
    total = 0
    matched = 0
    for pattern in root:
        if pattern.tag != "textbox" or pattern.get("action") != "check_format":
            continue
        total += 1
        position = parse_position(pattern.get("position"), frame, root)
        # get_text() terminates extracted lines with "\n", so compare stripped text
        extracted_text = page.get_text(clip=position)
        if extracted_text.strip() == (pattern.text or "").strip():
            matched += 1
    return matched / total if total else 0.0

In [65]:
# Quick check of the "#<id>.<coord>" reference syntax used in pattern positions.
p = "#12.y1"
m = re.match(r"#(\d+)\.", p)
print(m)

<re.Match object; span=(0, 4), match='#12.'>


In [56]:
# Page rectangle as (x0, y0, x1, y1); match_score passes the same tuple to parse_position as the frame.
tuple(page.rect)

(0.0, 0.0, 1364.0, 1065.0)