# Document Layout Analysis

**Author:** Alan Meeson <alan@carefullycalculated.co.uk>

**Date:** 2023-07-09

Apply document layout analysis to break the paper down into its components.

In [None]:
import os
import io
import fitz
import torch
import torchvision
import json
import numpy as np
import matplotlib.pyplot as plt
import layoutparser as lp
import pytesseract

from tqdm.notebook import tqdm
from typing import List, Dict, Set, Union
from pyprojroot import here
from PIL import Image

## Load a paper and display a page

In [None]:
# Path to the sample paper: <project root>/data/paper.pdf, with the root located by pyprojroot's here().
paper_pdf = os.path.join(here(), 'data', 'paper.pdf')

In [None]:
# Open the paper with PyMuPDF (imported as fitz); pages are then available by index.
pdf = fitz.open(paper_pdf)

In [None]:
# Render the first page through a 2x scaling matrix.
page = pdf[0]
pix = page.get_pixmap(matrix=fitz.Matrix(2, 2))

# Pick the PIL mode that matches whether the pixmap carries an alpha channel,
# then wrap the raw sample bytes in a PIL image and display it.
mode = "RGBA" if pix.alpha else "RGB"
img = Image.frombytes(mode, [pix.width, pix.height], pix.samples)
img

## Declare & Apply the analysis pipeline

In [None]:
# Build the Detectron2-backed layout detector from the config and weights stored
# under the project's model/ directory.
model = lp.Detectron2LayoutModel(
    config_path=os.path.join(here(), 'model', 'config.yaml'), 
    model_path=os.path.join(here(), 'model', 'model_final.pth'),
    # NOTE(review): presumably discards detections scoring below 0.5 - confirm against the Detectron2 config reference
    extra_config=["MODEL.ROI_HEADS.SCORE_THRESH_TEST", 0.5], 
    # Maps the model's integer class ids to the block type names tested throughout this notebook
    label_map={0: "Text", 1: "Title", 2: "List", 3: "Table", 4: "Figure"}
)

In [None]:
# Run layout detection on the rendered page image.
layout = model.detect(img)

In [None]:
# Draw the detected blocks over the page image for visual inspection.
lp.draw_box(img, layout, box_width=3)

## Extract Sections

### Identify Columns and sort blocks

In [None]:
# Histogram of block widths; the column-count estimate in this section is based on their median.
widths = [bl.width for bl in layout]
plt.hist(widths)

In [None]:
# Estimate the number of columns as page width floor-divided by the median block width.
num_cols = int(img.width // np.median(widths))
num_cols

In [None]:
# Sort the layout in place into reading order: first by column (left edge
# floor-divided by the median block width), then top to bottom within a column.
column_width = np.median([bl.width for bl in layout])
layout.sort(key = lambda x: (
    x.coordinates[0] // column_width, # column number
    x.coordinates[1] # Y position
), inplace=True)

In [None]:
# Show the first two coordinates of each block to check the sort order by eye.
[(l.coordinates[0], l.coordinates[1]) for l in layout]

### Extract Text

In [None]:
# Keep only the textual block types and inspect the first one.
# NOTE(review): text_blocks[0] raises IndexError if the page has no textual block.
text_blocks = lp.Layout([b for b in layout if b.type in {'Text', 'Title', 'List'}])
tb = text_blocks[0]
tb

In [None]:
# Crop the padded block region out of the page image and show it.
text_image = tb.pad(15,5,15,5).crop_image(np.array(img))
plt.imshow(text_image)

In [None]:
# Run Tesseract OCR on the cropped block and display the raw result.
text = pytesseract.image_to_string(text_image, config='--oem 3 --psm 6')
text
#probable_caption.set(text=text, inplace=True)

In [None]:
# Re-join words hyphenated across a line break, then turn remaining newlines into spaces.
text.replace("-\n", '').replace('\n', ' ')

In [None]:
# OCR every textual block on the page and store the cleaned text on the block.
# The page image is converted to an array once; it does not change inside the loop.
page_pixels = np.array(img)
for block in layout:
    if block.type in {"Text", "Title", "List"}:
        text_image = block.pad(5,5,5,5).crop_image(page_pixels)
        text = pytesseract.image_to_string(text_image, config='--oem 3 --psm 6')
        # Re-join words hyphenated across a line break, then flatten newlines to spaces.
        text = text.replace("-\n", '').replace('\n', ' ').strip()
        block.set(text=text, inplace=True)

## Extract an image

In [None]:
# Positions (within the sorted layout) of every block detected as a figure; inspect the first.
# NOTE(review): figure_idxs[0] raises IndexError if this page has no figure.
figure_idxs = [idx for idx, b in enumerate(layout) if b.type=='Figure']
fi = figure_idxs[0]
fb = layout[fi]
fb

In [None]:
# Show the figure region, padded by 15 on every side.
plt.imshow(fb.pad(15,15,15,15).crop_image(np.array(img)))

In [None]:
# The block right after the figure in the sorted layout is the first caption candidate.
# NOTE(review): raises IndexError when the figure is the last block on the page.
candidate_caption = layout[fi+1]
candidate_caption

In [None]:
#candidate_caption_text = candidate_caption.text.lower()
# Treat the candidate as a caption when its lower-cased text starts with 'fig' or 'figure'.
# NOTE(review): .text is only set by the OCR loop for Text/Title/List blocks; confirm it is
# not None for other block types before calling .lower().
probably_figure_caption = any(map(candidate_caption.text.lower().startswith, ['fig', 'figure']))
probably_figure_caption

In [None]:
# Check the block after, then the block before, the figure (skipping indices that
# fall outside the layout) and keep those whose text starts like a figure caption.
# Only Text/Title/List blocks were OCR'd, so a neighbouring figure or table still
# has text=None; `or ''` keeps such a neighbour from raising AttributeError.
candidate_captions = [
    layout[idx] for idx in [fi+1, fi-1]
    if 0 <= idx < len(layout)
    and (layout[idx].text or '').lower().startswith(('fig', 'figure'))
]

caption = candidate_captions[0].text if candidate_captions else None
caption

## Extract a Table

In [None]:
# Positions (within the sorted layout) of every block detected as a table; inspect the first.
# NOTE(review): table_idxs[0] raises IndexError if this page has no table.
# Note that `tb` is rebound here; earlier in the notebook it held a text block.
table_idxs = [idx for idx, b in enumerate(layout) if b.type=='Table']
ti = table_idxs[0]
tb = layout[ti]
tb

In [None]:
# Show the table region, padded by 15 on every side.
plt.imshow(tb.pad(15,15,15,15).crop_image(np.array(img)))

In [None]:
# Check the block before, then the block after, the table (skipping indices that
# fall outside the layout) and keep those whose text starts like a table caption.
# Only Text/Title/List blocks were OCR'd, so a neighbouring figure or table still
# has text=None; `or ''` keeps such a neighbour from raising AttributeError.
candidate_captions = [
    layout[idx] for idx in [ti-1, ti+1]
    if 0 <= idx < len(layout)
    and (layout[idx].text or '').lower().startswith(('tab', 'table'))
]

caption = candidate_captions[0].text if candidate_captions else None
caption

## Bring it all together and construct a Document

In [None]:
def get_page_image(page: fitz.Page, zoom: float = 2.0) -> Image.Image:
    """Render a PDF page to a PIL image.

    Args:
        page: The PyMuPDF page to rasterise.
        zoom: Scale factor applied to both axes when rendering. The default of
            2.0 reproduces the previously hard-coded ``fitz.Matrix(2, 2)``.

    Returns:
        The rendered page, in ``RGBA`` mode when the pixmap carries an alpha
        channel and ``RGB`` otherwise.
    """
    pix = page.get_pixmap(matrix=fitz.Matrix(zoom, zoom))
    mode = "RGBA" if pix.alpha else "RGB"

    return Image.frombytes(mode, [pix.width, pix.height], pix.samples)

In [None]:
def sort_layout_by_columns(layout: lp.Layout) -> lp.Layout:
    """Order the blocks of a page column by column, top to bottom.

    The median block width is used as the column width: a block's column
    number is its left edge floor-divided by that width, and blocks that share
    a column number are ordered by their top edge.

    Args:
        layout: The detected blocks for one page.

    Returns:
        Whatever ``layout.sort`` returns when called without ``inplace``.
    """
    widths = [block.width for block in layout]

    # A page with no detections has no median (np.median([]) is nan and emits a
    # RuntimeWarning), and a zero width would make the floor division
    # undefined. Any positive stand-in keeps the sort key well defined.
    column_width = np.median(widths) if widths else 1.0
    if not column_width > 0:
        column_width = 1.0

    return layout.sort(
        key=lambda block: (
            block.coordinates[0] // column_width,  # column number
            block.coordinates[1],  # Y position
        )
    )

In [None]:
def ocr_text_blocks(layout: lp.Layout, image: Image.Image) -> lp.Layout:
    """Run Tesseract OCR over every textual block and store the text on it.

    Only blocks of type 'Text', 'Title' or 'List' are processed; all other
    blocks are left untouched. Blocks are updated in place.

    Args:
        layout: The detected blocks for one page.
        image: The page image the blocks were detected on.

    Returns:
        The same ``layout`` object that was passed in.
    """
    # Convert the page once; it is identical for every block on the page.
    page_pixels = np.array(image)

    for block in layout:
        if block.type not in {"Text", "Title", "List"}:
            continue

        text_image = block.pad(15,5,15,5).crop_image(page_pixels)
        text = pytesseract.image_to_string(text_image, config='--oem 3 --psm 6')
        # Re-join words hyphenated across a line break, then flatten newlines to spaces.
        text = text.replace("-\n", '').replace('\n', ' ').strip()
        block.set(text=text, inplace=True)

    return layout

def extract_text_blocks(layout: lp.Layout, image: Image, page_num: int) -> List[Dict]:
    """Collect the textual blocks of a page as plain dictionaries.

    Blocks of type 'Text', 'Title' or 'List' are kept. Each record holds the
    block's text, type, coordinates (as x1/y1/x2/y2), score, the page number
    and the block's position within ``layout``. ``image`` is not used.
    """
    textual_types = {'Text', 'Title', 'List'}
    records = []

    for block_idx, block in enumerate(layout):
        if block.type not in textual_types:
            continue

        coords = block.block.coordinates
        records.append({
            'text': block.text,
            'type': block.type,
            'coordinates': {
                'x1': coords[0],
                'y1': coords[1],
                'x2': coords[2],
                'y2': coords[3],
            },
            'score': block.score,
            'page': page_num,
            'block_id': block_idx,
        })

    return records

In [None]:
def _identify_caption(target_idx: int, layout: lp.Layout, first_pass_offsets: List = None, candidates_start_with: Set = None) -> Union[str, None]:
    """Find the caption text that most plausibly belongs to a figure or table.

    Two strategies are tried in order:

    1. If ``first_pass_offsets`` is given, the blocks at ``target_idx + offset``
       are checked in the order the offsets are listed; indices that fall
       outside the layout are skipped.
    2. If that yields nothing, every 'Text', 'Title' or 'List' block on the
       page is considered and the one whose centre is closest to the target
       block's centre is chosen.

    In both passes, when ``candidates_start_with`` is given, a block only
    qualifies if its lower-cased text starts with one of those prefixes.

    Args:
        target_idx: Position of the figure/table block within ``layout``.
        layout: The blocks of the page.
        first_pass_offsets: Index offsets of the neighbouring blocks to try
            first, in priority order.
        candidates_start_with: Lower-case prefixes a caption must start with.

    Returns:
        The text of the chosen block, or ``None`` when no block qualifies.
    """
    prefixes = tuple(candidates_start_with) if candidates_start_with else None

    def is_plausible(block) -> bool:
        if prefixes is None:
            return True
        # Blocks that were never OCR'd (e.g. figures and tables) have text=None.
        return (block.text or '').lower().startswith(prefixes)

    if first_pass_offsets:
        # First try the listed neighbours, remembering to handle boundary cases.
        num_blocks = len(layout)
        neighbour_idxs = [target_idx + offset for offset in first_pass_offsets]
        neighbours = [
            layout[idx] for idx in neighbour_idxs
            if 0 <= idx < num_blocks and is_plausible(layout[idx])
        ]
        if neighbours:
            return neighbours[0].text

    # Failing that, consider all plausible text blocks and pick the nearest one.
    text_candidates = [
        block for block in layout
        if block.type in {'Text', 'Title', 'List'} and is_plausible(block)
    ]
    if not text_candidates:
        return None

    target_center = layout[target_idx].block.center

    def squared_distance(block) -> float:
        return sum((a - b) ** 2 for a, b in zip(block.block.center, target_center))

    # min() keeps the first of several equally distant blocks, as a stable sort would.
    return min(text_candidates, key=squared_distance).text

In [None]:
def extract_figure_blocks(layout: lp.Layout, image: Image.Image, page_num: int) -> List[Dict]:
    """Collect every figure on a page together with its cropped image and caption.

    Args:
        layout: The blocks of the page.
        image: The page image the blocks were detected on.
        page_num: Page number stored in each record.

    Returns:
        One dictionary per block of type 'Figure', holding the cropped image
        (padded by 15 on every side), the caption text or ``None``, the
        block's type, coordinates (as x1/y1/x2/y2), score, the page number and
        the block's position within ``layout``.
    """
    figures = []

    # Convert the page once; it is identical for every figure on the page.
    page_pixels = np.array(image)

    for figure_idx, figure_block in enumerate(layout):
        if figure_block.type != 'Figure':
            continue

        figure_image = figure_block.pad(15,15,15,15).crop_image(page_pixels)
        # Try the block after the figure first, then the one before it.
        caption = _identify_caption(figure_idx, layout, first_pass_offsets=[1, -1], candidates_start_with={'fig', 'figure'})

        # Describe the figure block itself, not whatever a global named
        # `block` happens to hold.
        coords = figure_block.block.coordinates
        figures.append({
            'image': figure_image,
            'caption': caption,
            'type': figure_block.type,
            'coordinates': {
                'x1': coords[0],
                'y1': coords[1],
                'x2': coords[2],
                'y2': coords[3],
            },
            'score': figure_block.score,
            'page': page_num,
            'block_id': figure_idx,
        })

    return figures

In [None]:
def extract_table_blocks(layout: lp.Layout, image: Image.Image, page_num: int) -> List[Dict]:
    """Collect every table on a page together with its cropped image and caption.

    Args:
        layout: The blocks of the page.
        image: The page image the blocks were detected on.
        page_num: Page number stored in each record.

    Returns:
        One dictionary per block of type 'Table', holding a ``'table'`` entry
        (currently always ``None``), the cropped image (padded by 15 on every
        side), the caption text or ``None``, the block's type, coordinates
        (as x1/y1/x2/y2), score, the page number and the block's position
        within ``layout``.
    """
    tables = []

    # Convert the page once; it is identical for every table on the page.
    page_pixels = np.array(image)

    for table_idx, table_block in enumerate(layout):
        if table_block.type != 'Table':
            continue

        table_image = table_block.pad(15,15,15,15).crop_image(page_pixels)
        # Try the block before the table first, then the one after it.
        caption = _identify_caption(table_idx, layout, first_pass_offsets=[-1, 1], candidates_start_with={'tab', 'table'})

        # Describe the table block itself, not whatever a global named
        # `block` happens to hold.
        coords = table_block.block.coordinates
        tables.append({
            'table': None, # TODO: parse image into pandas dataframe
            'image': table_image,
            'caption': caption,
            'type': table_block.type,
            'coordinates': {
                'x1': coords[0],
                'y1': coords[1],
                'x2': coords[2],
                'y2': coords[3],
            },
            'score': table_block.score,
            'page': page_num,
            'block_id': table_idx,
        })

    return tables

In [None]:
# Re-open the sample paper so this section can be run on its own.
paper_pdf = os.path.join(here(), 'data', 'paper.pdf')
pdf = fitz.open(paper_pdf)

In [None]:
def parse_pdf(pdf: fitz.Document) -> Dict:
    """Run the full layout pipeline over every page of a PDF.

    Each page is rendered, passed through the module-level ``model``, sorted
    into column order and OCR'd; its text, figure and table records are then
    appended to the document-wide lists.

    Returns:
        A dictionary with the page count under ``'num_pages'`` and the
        collected records under ``'text'``, ``'figures'`` and ``'tables'``.
    """
    num_pages = len(pdf)
    all_text = []
    all_figures = []
    all_tables = []

    progress = tqdm(enumerate(pdf), total = num_pages)
    for page_num, page in progress:
        image = get_page_image(page)

        detected = model.detect(image)
        ordered = sort_layout_by_columns(detected)
        layout = ocr_text_blocks(ordered, image)

        all_text += extract_text_blocks(layout, image, page_num)
        all_figures += extract_figure_blocks(layout, image, page_num)
        all_tables += extract_table_blocks(layout, image, page_num)

    return {
        'num_pages': num_pages,
        'text': all_text,
        'figures': all_figures,
        'tables': all_tables,
    }

In [None]:
# Parse the whole paper into text, figure and table records.
document = parse_pdf(pdf)

### Output to JSON and extras

In [None]:
def _index_width(count: int) -> int:
    """Return how many digits are needed to zero-pad the indices 0..count-1."""
    return len(str(max(count - 1, 0)))


def output_json(document: Dict, out_path: str) -> None:
    """Write a parsed document to ``out_path`` as JSON plus image/CSV files.

    Every figure image is saved as ``figureN.png`` and every table image as
    ``tableN.png``, with N zero-padded to a common width. A table whose
    ``'table'`` entry is not ``None`` is also saved as ``tableN.csv``. The
    document itself is written to ``document.json``.

    Note that ``document`` is modified in place: each ``'image'`` entry (and
    each saved ``'table'`` entry) is replaced by the name of the file it was
    written to, which is what makes the document JSON-serialisable.

    Args:
        document: The dictionary produced by ``parse_pdf``.
        out_path: Directory to write into; created if it does not exist.
    """
    os.makedirs(out_path, exist_ok=True)

    # save figures as png
    figures = document['figures']
    # Digit count is derived from the length itself, so an empty list is fine
    # (log10(0) is -inf, which cannot be converted to int).
    num_figure_digits = _index_width(len(figures))
    for figure_num, figure in enumerate(figures):
        figure_name = "figure%0*d.png" % (num_figure_digits, figure_num)

        figure_image = Image.fromarray(figure['image'])
        figure_image.save(os.path.join(out_path, figure_name), 'png')

        figure['image'] = figure_name

    # save tables as png/csv
    tables = document['tables']
    num_table_digits = _index_width(len(tables))
    for table_num, table in enumerate(tables):
        table_name = "table%0*d" % (num_table_digits, table_num)
        table_image_name = f"{table_name}.png"
        table_csv_name = f"{table_name}.csv"

        table_image = Image.fromarray(table['image'])
        table_image.save(os.path.join(out_path, table_image_name), 'png')
        table['image'] = table_image_name

        table_data = table.get('table')
        # Compare against None explicitly: the truth value of a pandas
        # DataFrame is ambiguous and raises ValueError.
        if table_data is not None:
            table_data.to_csv(os.path.join(out_path, table_csv_name), index=False)
            table['table'] = table_csv_name

    # output the document as json
    document_json_file = os.path.join(out_path, 'document.json')
    with open(document_json_file, 'w') as fp:
        json.dump(document, fp)
    

In [None]:
# Write document.json and the extracted images into the current working directory.
output_json(document, '.')

In [None]:
# Load a table image from the project's test/ directory.
image = Image.open(os.path.join(here(), 'test', 'table0.png'))

In [None]:
# Put the parent directory at the front of the module search path.
# NOTE(review): presumably so the local article_parser package can be imported - confirm the layout.
import sys
sys.path.insert(0, '..')

In [None]:
from article_parser.table_parser import extract_table_data

In [None]:
# Self-contained reload of the table image (os and here are already imported
# at the top of the notebook), then display it.
import PIL
import os
from pyprojroot import here
image = PIL.Image.open(os.path.join(here(), 'test', 'table0.png'))
image

In [None]:
# Parse the table image with the project's table parser.
df = extract_table_data(image)

In [None]:
# Display the parsed table.
df

In [None]:
# Save the parsed table to foo.csv without the index column or header row.
df.to_csv('foo.csv', index=False, header=False)