In [1]:
import os
import cv2
import json
import yaml
import time
import pytz
import datetime
import argparse
import shutil
import torch
import numpy as np
import gc

from paddleocr import draw_ocr
from PIL import Image, ImageDraw, ImageFont
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from ultralytics import YOLO
from unimernet.common.config import Config
import unimernet.tasks as tasks
from unimernet.processors import load_processor
from struct_eqtable import build_model

from modules.latex2png import tex2pil, zhtext2pil
from modules.extract_pdf import load_pdf_fitz
from modules.layoutlmv3.model_init import Layoutlmv3_Predictor
from modules.self_modify import ModifiedPaddleOCR
from modules.post_process import get_croped_image, latex_rm_whitespace

In [2]:
def mfd_model_init(weight):
    """Create the formula-detection model.

    Args:
        weight: Passed unchanged to the ``YOLO`` constructor
            (presumably a weights path — confirm against the config file).

    Returns:
        The ``YOLO`` instance built from ``weight``.
    """
    return YOLO(weight)


def mfr_model_init(weight_dir, device='cpu'):
    """Build the formula-recognition model together with its image processor.

    The demo config at ``modules/UniMERNet/configs/demo.yaml`` is loaded and
    its pretrained-weights path, model name and tokenizer path are redirected
    to ``weight_dir`` before the task builds the model.

    Args:
        weight_dir: Directory used for ``pytorch_model.bin``, the model name
            and the tokenizer path.
        device: Target passed to ``model.to(...)``; defaults to ``'cpu'``.

    Returns:
        A ``(model, vis_processor)`` tuple.
    """
    cfg = Config(argparse.Namespace(cfg_path="modules/UniMERNet/configs/demo.yaml", options=None))

    # Point every weight-related entry of the config at the given directory.
    cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.bin")
    cfg.config.model.model_config.model_name = weight_dir
    cfg.config.model.tokenizer_config.path = weight_dir

    task = tasks.setup_task(cfg)
    model = task.build_model(cfg).to(device)

    eval_processor_cfg = cfg.config.datasets.formula_rec_eval.vis_processor.eval
    vis_processor = load_processor('formula_image_eval', eval_processor_cfg)
    return model, vis_processor

def layout_model_init(weight):
    """Create the layout-detection predictor.

    Args:
        weight: Passed unchanged to ``Layoutlmv3_Predictor``.

    Returns:
        The constructed ``Layoutlmv3_Predictor``.
    """
    return Layoutlmv3_Predictor(weight)

def tr_model_init(weight, max_time, device='cuda'):
    """Build the table-recognition model.

    Args:
        weight: Passed to ``build_model`` as its first argument.
        max_time: Forwarded to ``build_model`` as ``max_time``.
        device: The model is moved with ``.cuda()`` only when this is exactly
            ``'cuda'``; any other value leaves the model where it was built.

    Returns:
        The table-recognition model.
    """
    table_model = build_model(weight, max_new_tokens=4096, max_time=max_time)
    return table_model.cuda() if device == 'cuda' else table_model

# Run settings, held in a Namespace so later cells read them like parsed
# command-line arguments (args.pdf, args.output, ...).
args = argparse.Namespace(**{
    'pdf': './test_data/',
    'output': 'output',
    'batch_size': 1,
    'vis': True,
    'render': False,
})

# Timestamp of the notebook start, in the Asia/Shanghai time zone.
tz = pytz.timezone('Asia/Shanghai')
now = datetime.datetime.now(tz)
print(now.strftime('%Y-%m-%d %H:%M:%S'), 'Started!', sep='\n')

2024-09-01 18:11:50
Started!


In [None]:
# Load the model/runtime settings once and expose the frequently used values
# as notebook-level variables for the cells below.
with open('configs/model_configs.yaml', encoding='utf-8') as f:
    # NOTE(review): FullLoader is kept for compatibility with the existing
    # config file; prefer yaml.safe_load if the file only holds plain values.
    model_configs = yaml.load(f, Loader=yaml.FullLoader)

model_args = model_configs['model_args']
img_size = model_args['img_size']
conf_thres = model_args['conf_thres']
iou_thres = model_args['iou_thres']
device = model_args['device']
dpi = model_args['pdf_dpi']

# Table recognition, layout detection and OCR models used by later cells.
tr_model = tr_model_init(model_args['tr_weight'], max_time=model_args['table_max_time'], device=device)
layout_model = layout_model_init(model_args['layout_weight'])
ocr_model = ModifiedPaddleOCR(show_log=True)

# Take a fresh timestamp here: `now` was captured when the notebook started,
# so printing it again would not show when initialisation actually finished.
print(datetime.datetime.now(tz).strftime('%Y-%m-%d %H:%M:%S'))
print('Model init done!')

In [4]:
start = time.time()

# Collect the input files: every entry of a directory, or the single path.
# Sorting makes the processing order (and the printed indices) reproducible.
if os.path.isdir(args.pdf):
    all_pdfs = [os.path.join(args.pdf, name) for name in sorted(os.listdir(args.pdf))]
else:
    all_pdfs = [args.pdf]
print("total files:", len(all_pdfs))

# NOTE(review): `img_list` and `doc_layout_result` are rebound for every
# file, so the later cells only see the last file handled here.
for pdf_idx, single_pdf in enumerate(all_pdfs):
    try:
        img_list = load_pdf_fitz(single_pdf, dpi=dpi)
    except Exception as exc:
        # Skip unreadable inputs, but report why, and do not swallow
        # KeyboardInterrupt/SystemExit the way a bare `except:` would.
        img_list = None
        print("unexpected pdf file:", single_pdf, "-", repr(exc))
    if img_list is None:
        continue
    print("pdf index:", pdf_idx, "pages:", len(img_list))

    # layout detection and formula detection
    doc_layout_result = []
    latex_filling_list = []  # initialised here, not filled in this cell
    mf_image_list = []       # initialised here, not filled in this cell
    # A separate loop variable keeps the page number from shadowing the
    # file index of the enclosing loop.
    for page_no, image in enumerate(img_list):
        img_H, img_W = image.shape[0], image.shape[1]
        layout_res = layout_model(image, ignore_catids=[])
        layout_res['page_info'] = dict(
            page_no = page_no,
            height = img_H,
            width = img_W
        )
        doc_layout_result.append(layout_res)

total files: 1
pdf index: 0 pages: 3


In [5]:
# One outline/text colour per category id (indexed by `category_id`).
color_palette = [
    (255,64,255),(255,255,0),(0,255,255),(255,215,135),(215,0,95),(100,0,48),(0,175,0),(95,0,95),(175,95,0),(95,95,0),
    (95,95,255),(95,175,135),(215,95,0),(0,0,255),(0,255,0),(255,0,0),(0,95,215),(0,0,0),(0,0,0),(0,0,0)
]
# Category id -> display name; ids above 15 are not visualised.
id2names = ["title", "plain_text", "abandon", "figure", "figure_caption", "table", "table_caption", "table_footnote",
            "isolate_formula", "formula_caption", " ", " ", " ", "inline_formula", "isolated_formula", "ocr_text"]
vis_pdf_result = []

# Create the output directory before anything is written into it; saving the
# first page would otherwise fail on a fresh checkout.
output_dir = args.output
os.makedirs(output_dir, exist_ok=True)

# The default font is the same for every label, so load it once.
fontText = ImageFont.load_default()

for idx, image in enumerate(img_list):
    single_page_res = doc_layout_result[idx]['layout_dets']
    # Draw either on a blank page (render mode) or on the page image itself.
    # NOTE(review): the channel swap presumably matches the array layout
    # returned by load_pdf_fitz — confirm.
    if args.render:
        vis_img = Image.new('RGB', Image.fromarray(image).size, 'white')
    else:
        vis_img = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
    draw = ImageDraw.Draw(vis_img)

    for res in single_page_res:
        label = int(res['category_id'])
        if label > 15:     # categories that do not need visualize
            continue
        label_name = id2names[label]
        # poly holds the box corners; entries 0/1 and 4/5 are used as the
        # top-left and bottom-right points.
        x_min, y_min = int(res['poly'][0]), int(res['poly'][1])
        x_max, y_max = int(res['poly'][4]), int(res['poly'][5])
        draw.rectangle([x_min, y_min, x_max, y_max], fill=None, outline=color_palette[label], width=2)
        draw.text((x_min, y_min), label_name, fill=color_palette[label], font=fontText)

    # Shrink the annotated page to 75 % before saving.
    width, height = vis_img.size
    width, height = int(0.75*width), int(0.75*height)
    vis_img = vis_img.resize((width, height))
    vis_pdf_result.append(vis_img)

    # One single-page PDF and one JSON file per page, each written once.
    vis_img.save(os.path.join(output_dir, f'page_{idx}_layout.pdf'), 'PDF', resolution=100)
    with open(os.path.join(output_dir, f'page_{idx}_layout.json'), 'w') as json_file:
        json.dump(single_page_res, json_file, indent=4)

In [None]:
# OCR recognition: run OCR on every layout region and attach the recognised
# lines to that region as `ocr_result`.

# Resolve and create the output directory once, up front, so it exists (and
# `output_dir` is bound) even when no page contains a layout region.
output_dir = args.output
os.makedirs(output_dir, exist_ok=True)

for idx, image in enumerate(img_list):
    pil_img = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
    single_page_res = doc_layout_result[idx]['layout_dets']
    single_page_mfdetrec_res = []  # passed to the OCR call as mfd_res; always empty here

    for res in single_page_res:
        xmin, ymin = int(res['poly'][0]), int(res['poly'][1])
        xmax, ymax = int(res['poly'][4]), int(res['poly'][5])
        crop_box = [xmin, ymin, xmax, ymax]
        # Paste the region onto a white canvas of the full page size so the
        # OCR boxes stay in page coordinates.
        cropped_img = Image.new('RGB', pil_img.size, 'white')
        cropped_img.paste(pil_img.crop(crop_box), crop_box)
        cropped_img = cv2.cvtColor(np.asarray(cropped_img), cv2.COLOR_RGB2BGR)
        ocr_res = ocr_model.ocr(cropped_img, mfd_res=single_page_mfdetrec_res)[0]
        if ocr_res:
            res['ocr_result'] = []
            for box_ocr_res in ocr_res:
                p1, p2, p3, p4 = box_ocr_res[0]
                text, score = box_ocr_res[1]
                res['ocr_result'].append({
                    # the four corner points concatenated into one flat list
                    'poly': p1 + p2 + p3 + p4,
                    'score': round(score, 2),
                    'text': text,
                })

# Persist the enriched layout results, one JSON file per page.
for idx, single_page_res in enumerate(doc_layout_result):
    with open(os.path.join(output_dir, f'page_{idx}_ocr_result.json'), 'w') as json_file:
        json.dump(single_page_res['layout_dets'], json_file, indent=4)

In [7]:
# Release cached GPU memory before the table-recognition step. The guard lets
# the notebook run on machines without CUDA, where there is nothing to free.
if torch.cuda.is_available():
    with torch.cuda.device('cuda'):
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

In [None]:
# Table recognition: run the table model on every layout region whose
# category id is 5 and store the generated LaTeX on that region.
for idx, image in enumerate(img_list):
    pil_img = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
    single_page_res = doc_layout_result[idx]['layout_dets']
    for res in single_page_res:
        if int(res['category_id']) != 5:  # only id 5 ("table") is processed
            continue
        print('table here')
        xmin, ymin = int(res['poly'][0]), int(res['poly'][1])
        xmax, ymax = int(res['poly'][4]), int(res['poly'][5])
        crop_box = [xmin, ymin, xmax, ymax]
        cropped_img = pil_img.convert("RGB").crop(crop_box)

        # Time the call with dedicated names so the notebook-level `start`
        # from the loading cell is not overwritten.
        table_start = time.time()
        with torch.no_grad():
            output = tr_model(cropped_img)
        table_elapsed = time.time() - table_start

        # Flag results whose generation ran past the configured time budget.
        if table_elapsed > model_configs['model_args']['table_max_time']:
            res["timeout"] = True
        res["latex_result"] = output[0]

# NOTE(review): the file name says "orc" (likely meant "ocr"); it is left
# unchanged because other tooling may read these files by name.
for idx, single_page_res in enumerate(doc_layout_result):
    with open(os.path.join(output_dir, f'page_{idx}_orc_table_result.json'), 'w') as json_file:
        json.dump(single_page_res['layout_dets'], json_file, indent=4)

In [9]:
# Render every recognised table (category id 5 with a LaTeX result) to an
# image and store it next to the other per-page outputs.
for idx, single_page_res in enumerate(doc_layout_result):
    for res_index, res in enumerate(single_page_res['layout_dets']):
        if 'latex_result' not in res or int(res['category_id']) != 5:
            continue
        latex = res['latex_result']
        print(latex)
        # tex2pil returns a sequence; its first element is the image saved here.
        latex_image = tex2pil(latex)[0]
        image_path = os.path.join(output_dir,f'page_{idx}_latex_image_{res_index}.jpg')
        latex_image.save(image_path)

\begin{tabular}{@{}lccc@{}}\toprule\textbf{Variable} & \textbf{Pre} & \textbf{ during} & \textbf{P value} \\\midrule \textbf{Klocsiories} & \textbf{$2185 \pm 94$} & \textbf{$172 \pm 85$} & \textbf{00005} \\\textbf{Prouen(g)} & \textbf{$92 \pm 6$} & \textbf{$62 \pm 5$} & \textbf{00003} \\\textbf{Prouen(R){\tiny 图} } & \textbf{17 \pm 0} & \textbf{13 \pm 0} & \textbf{0004} \\\textbf{Carbohydrate(g)} & \textbf{287 \pm I4} & \textbf{269 \pm 17} & \textbf{041} \\\textbf{Carbohydrate(g)} & \textbf{33 \pm 0} & \textbf{62 \pm 0} & \textbf{00002} \\\textbf{Faber(g)} & \textbf{26 \pm 2} & \textbf{40 \pm 3} & \textbf{$< 00001$} \\\textbf{Sugar(g)} & \textbf{96 \pm 7} & \textbf{88 \pm 6} & \textbf{037} \\\textbf{Fat(g)} & \textbf{74 \pm 5} & \textbf{54 \pm 4} & \textbf{0003} \\\textbf{Fat(G)} & \textbf{30 \pm 0} & \textbf{27 \pm 0} & \textbf{020} \\\textbf{SaturatedFat(g)} & \textbf{24 \pm 2} & \textbf{9 \pm 1} & \textbf{$< 00001$} \\\textbf{MonoursarvatedFat(g)} & \textbf{14 \pm 2} & \textbf{14 \p

In [165]:
def process_ocr_result(ocr_reuslt):
    """Reduce OCR entries to ``(x_center, y_center, text)`` triples.

    The function iterates over its own argument. It used to read a
    module-level ``ocr_result`` instead, which silently ignored the argument
    and failed whenever that global was missing or stale.

    Args:
        ocr_reuslt: Iterable of dicts with a ``'poly'`` sequence (entries 0/1
            and 4/5 are read as opposite box corners) and a ``'text'`` string.
            The historical spelling of the parameter name is kept so existing
            keyword callers keep working.

    Returns:
        A list of ``(x_avg, y_avg, text)`` tuples in input order. Corner
        coordinates are truncated with ``int()`` before averaging.
    """
    ocr_result_ordered = []
    for res in ocr_reuslt:
        xmin, ymin = int(res['poly'][0]), int(res['poly'][1])
        xmax, ymax = int(res['poly'][4]), int(res['poly'][5])
        x_avg = (xmin + xmax)/2
        y_avg = (ymin + ymax)/2
        ocr_result_ordered.append((x_avg, y_avg, res['text']))
    return ocr_result_ordered

# Build a compact, LLM-friendly description of every page: one dict per layout
# region with its position, category name and any OCR / LaTeX content.
doc_llm = []
for idx, page in enumerate(doc_layout_result):
    res_llm = []
    layout = page['layout_dets']
    for res in layout:
        layout_llm = {}
        xmin, ymin = int(res['poly'][0]), int(res['poly'][1])
        xmax, ymax = int(res['poly'][4]), int(res['poly'][5])
        x_avg = (xmin + xmax)/2
        y_avg = (ymin + ymax)/2
        layout_llm['center_position'] = (x_avg,y_avg)
        layout_llm['box_position'] = (xmin,ymin,xmax,ymax)

        # Normalise the id with int(), as the other cells do, and fall back
        # to 'unknown' for ids outside id2names instead of raising IndexError.
        category_id = int(res['category_id'])
        if 0 <= category_id < len(id2names):
            layout_llm['category'] = id2names[category_id]
        else:
            layout_llm['category'] = 'unknown'

        if 'ocr_result' in res :
            ocr_result = res['ocr_result']
            layout_llm['ocr_result'] = process_ocr_result(ocr_result)
        if 'latex_result' in res:
            latex_result = res['latex_result']
            layout_llm['latex_result'] = latex_result
        res_llm.append(layout_llm)
    doc_llm.append(res_llm)

# for idx, single_page_res in enumerate(doc_llm):
#     with open(os.path.join(output_dir, f'page_{idx}_llm_result.json'), 'w') as json_file:
#         json.dump(single_page_res, json_file, indent=4)

# with open('/output/doc_llm.json', 'w') as f:
#     json.dump(doc_llm, f, ensure_ascii=False, indent=4)

In [177]:
def double_column_sort(data, split_x=800):
    """Order layout items for a two-column page: left column, then right.

    Args:
        data: Iterable of dicts whose ``'center_position'`` is an ``(x, y)``
            pair.
        split_x: Horizontal boundary between the columns, in the same units as
            the x coordinates. Items with ``x < split_x`` form the left
            column, all others the right column. Defaults to 800, the value
            that used to be hard-coded; adjust it for other page widths/DPI.

    Returns:
        A new list with the left-column items sorted by y, followed by the
        right-column items sorted by y. The input is not modified.
    """
    # Split the items into the two columns.
    left_column = [item for item in data if item['center_position'][0] < split_x]
    right_column = [item for item in data if item['center_position'][0] >= split_x]

    # Within each column, read top to bottom (sort by the y coordinate).
    left_sorted = sorted(left_column, key=lambda x: x['center_position'][1])
    right_sorted = sorted(right_column, key=lambda x: x['center_position'][1])

    # Left column first, then the right column.
    return left_sorted + right_sorted

import copy
import re
from enum import Enum

from pydantic import BaseModel

from utils import openai_wrapper

# Model used for both LLM calls, and the placeholder syntax the LLM is asked
# to emit wherever a table belongs.
LLM_MODEL = "gpt-4o-2024-08-06"
pattern = r"####(latex_\d+)####"

# System prompt for the layout-analysis call. The leading spaces inside the
# string are part of the prompt text and are kept as they were.
DocAnalysePrompt = """你是一个文本分析大师，你需要将一个杂乱的文本内容转化为有条理的markdown结构
    我们规定文本的左上角为0,0 其中x表示从左往右依次增大，y表示从上到下依次增大

    我们会输入一个list，其中的item都是的下面的结构：
    {'center_position': (x_avg,y_avg),
    'box_position': (xmin,ymin,xmax,ymax),
    'category': 'category',包含文本，标题等
    'ocr_result': [(x_avg, y_avg, ocr_result)]
    'latex_result' : 'latex'}
    ocr_result是一个list，里面是一个三元组(x_avg, y_avg, ocr_result)，里面ocr_result就是识别到的文本内容。
    这里面的ocr_result会有重叠的部分，你要根据对应的x_avg y_avg判断这些内容是不是重复的，并且忽略掉重复的部分

    你需要分析这个文本的结构，例如：是单栏，双栏或者是其他结构
    """

# System prompt template for the markdown-writing call; "{latex_list}" is
# replaced per page with the list of table placeholders.
DOC_WRITE_PROMPT_TEMPLATE = """你是一个文本格式化机，你需要将一个杂乱的文本内容转化为有条理的markdown结构，你不会遗漏原文中的任何信息
我们规定文本的左上角为0,0,我们使用(x,y)来表示位置 其中x表示从左往右依次增大，y表示从上到下依次增大

我们会输入一个list，其中的item都是的下面的结构：
{'category': 'category',包含文本，标题等
'ocr_result': [(x_avg, y_avg, ocr_result)]
'latex_result' : 'latex_i'}
ocr_result是一个list，里面是一个三元组(x_avg, y_avg, ocr_result)，里面ocr_result就是识别到的文本内容。在这里面你也是要根据对应的y_avg去判断元素所应该在的位置
这里面的ocr_result会有重叠的部分，你要根据对应的x_avg y_avg判断这些内容是不是重复的，并且忽略掉重复的部分
你要注意其中我们对于其中的图表使用了{latex_list}去进行替代，你要在对应的位置使用####latex_i####来代替对应的图表。
请你注意我们的category识别可能是会有问题的，你要对这里面的明显错误去进行纠正
你要将文本中提到的内容都严格地完成地返回，不要遗漏任何内容，不要省略任何内容，

错误示例：
没有将内容都返回，反而给出了省略号，同时返回了不存在的图表
# title
....
####latex_i####
# next title
没有严格按照原文本的内容返回，而是添加了其他的内容
正确示例：
# title
the text from the input 
####latex_1#### 要展示的图表
请你务必严格遵守上面的规则，否则我会失去这份工作
"""


class ColumnType(str, Enum):
    """Page layouts the analysis call may report."""
    single_column = "single_column"
    double_column = "double_column"
    others = "others"


class DocAnalyseResponseFormat(BaseModel):
    """Structured response requested from the layout-analysis call."""
    very_short_analysis : str
    ColumnType : ColumnType


class DocWriteResponseFormat(BaseModel):
    """Structured response requested from the markdown-writing call."""
    markdown : str


def post_process_replace(pattern,replace_dict,markdown_result):
    """Replace every ``pattern`` match by its entry in ``replace_dict``.

    Group 1 of each match is the lookup key; matches without an entry are
    left unchanged. A callable replacement is used so backslashes in the
    LaTeX text are inserted literally.
    """
    filtered_result = re.sub(pattern, lambda match: replace_dict.get(match.group(1), match.group(0)), markdown_result)
    return filtered_result


def _prepare_page_for_llm(page_items, column_type):
    """Return ``(page, name2latex)`` for the markdown-writing prompt.

    Works on a deep copy of ``page_items``: items are put in reading order,
    positional fields are dropped, and every LaTeX result is swapped for a
    short ``latex_<index>`` placeholder whose original text is collected in
    ``name2latex``.
    """
    single_column = column_type is ColumnType.single_column
    page = copy.deepcopy(page_items)
    if single_column:
        page = sorted(page, key=lambda x: x['center_position'][1])
    else:
        page = double_column_sort(page)

    name2latex = {}
    for index, item in enumerate(page):
        item.pop('box_position', None)
        # Only the single-column path drops the centre position; the
        # double-column path keeps it in the prompt input.
        if single_column:
            item.pop('center_position', None)
        if 'latex_result' in item:
            placeholder = f"latex_{index}"
            name2latex[placeholder] = item['latex_result']
            item['latex_result'] = placeholder
            # A table is represented by its LaTeX only, not by OCR text.
            item.pop('ocr_result', None)
        elif 'ocr_result' in item:
            item['ocr_result'] = sorted(item['ocr_result'], key=lambda x: x[1])
    return page, name2latex


for idx,single_page_llm in enumerate(doc_llm):
    if idx == 0:
        # NOTE(review): the first page is skipped unconditionally — confirm
        # this is intended and not a leftover from debugging.
        continue

    # Step 1: ask the LLM for the page's column layout, using positions and
    # categories only.
    DocAnalysePage = [{'center_position': item['center_position'], 'box_position': item['box_position'], 'category': item['category']} for item in single_page_llm]
    DocAnalyse_result =  openai_wrapper(DocAnalysePrompt,DocAnalysePage,DocAnalyseResponseFormat,model=LLM_MODEL)

    column_type = DocAnalyse_result.ColumnType
    if column_type is not ColumnType.single_column and column_type is not ColumnType.double_column:
        # Unsupported layout: skip this page only. Leaving the loop here
        # would silently drop every remaining page of the document.
        print(f"page {idx}: unsupported column type {column_type!r}, skipped")
        continue

    # Step 2: order the page content and replace tables by placeholders.
    page, name2latex = _prepare_page_for_llm(single_page_llm, column_type)
    latex_name_list = list(name2latex)
    print(name2latex)

    # Step 3: let the LLM write the markdown, then restore the real LaTeX.
    DocWritePrompt = DOC_WRITE_PROMPT_TEMPLATE.replace("{latex_list}", str(latex_name_list))
    result =  openai_wrapper(DocWritePrompt,f"input text : {page}",DocWriteResponseFormat,model=LLM_MODEL)
    markdown_result = result.markdown
    final_result = post_process_replace(pattern,name2latex,markdown_result)
    print(final_result)

    # The markdown contains non-ASCII text, so do not depend on the platform
    # default encoding when writing it.
    with open(os.path.join(output_dir, f'page_{idx}_openai.md'), 'w', encoding='utf-8') as file:
        file.write(final_result)


{'latex_2': "\\begin{tabular}{@{}lccc@{}}\\hline \\hline \\textbf{\\tiny 标志} & \\phantom{ab} & \\textbf{\\tiny In} & \\phantom{ab} & \\textbf{\\tiny Post} \\\\\\hline \\textbf{\\tiny Onosteol(mgdt)''} && \\textbf{\\tiny 1/11\\,±\\,46} && \\textbf{\\tiny 1387\\,±\\,44} \\\\\\textbf{\\tiny 70lycerodes(mgdt)''} && \\textbf{\\tiny 85.1\\,±\\,48} && \\textbf{\\tiny 753\\,±\\,36} \\\\\\textbf{\\tiny Hol-C(mgdt)''} && \\textbf{\\tiny 556\\,±\\,23} && \\textbf{\\tiny 476\\,±\\,22} \\\\\\textbf{\\tiny MDL-C(mgdt)''} && \\textbf{\\tiny 170\\,±\\,10} && \\textbf{\\tiny 180\\,±\\,07} \\\\\\textbf{\\tiny DDLC(mgdt)''} && \\textbf{\\tiny 984\\,±\\,39} && \\textbf{\\tiny 761\\,±\\,35} \\\\\\textbf{\\tiny Fina+DL-C} && \\textbf{\\tiny 13\\,±\\,0]} && \\textbf{\\tiny 11\\,±\\,01} \\\\\\hline \\hline \\end{tabular}", 'latex_6': '\\begin{tabular}{@{}lccc@{}}\\toprule\\textbf{Variable} & \\textbf{Pre} & \\textbf{ during} & \\textbf{P value} \\\\\\midrule \\textbf{Klocsiories} & \\textbf{$2185 \\pm 94$} & 