From a16afd5eb0bd72e827eb599185a19c0e61a6ba6a Mon Sep 17 00:00:00 2001 From: Jokcer <519548295@qq.com> Date: Sat, 8 Mar 2025 01:21:31 +0800 Subject: [PATCH 1/8] feat: trans to rapidTable code style --- demo_lineless.py | 51 +- demo_wired.py | 62 +- lineless_table_rec/main.py | 185 ++-- lineless_table_rec/process.py | 816 +++++++++--------- lineless_table_rec/table_structure_lore.py | 92 ++ lineless_table_rec/utils.py | 180 ---- lineless_table_rec/utils/__init__.py | 0 lineless_table_rec/utils/download_model.py | 67 ++ lineless_table_rec/utils/logger.py | 21 + lineless_table_rec/utils/utils.py | 477 ++++++++++ .../utils/utils_table_lore_rec.py | 408 +++++++++ .../{ => utils}/utils_table_recover.py | 0 tests/test_lineless_table_rec.py | 2 +- tests/test_wired_table_line_util.py | 2 +- tests/test_wired_table_rec.py | 2 +- wired_table_rec/__init__.py | 2 +- wired_table_rec/main.py | 124 ++- ...py => table_structure_cycle_center_net.py} | 12 +- ...ne_rec_plus.py => table_structure_unet.py} | 12 +- wired_table_rec/utils.py | 397 --------- wired_table_rec/utils/__init__.py | 0 wired_table_rec/utils/download_model.py | 67 ++ wired_table_rec/utils/logger.py | 21 + wired_table_rec/utils/utils.py | 694 +++++++++++++++ .../{ => utils}/utils_table_line_rec.py | 0 .../{ => utils}/utils_table_recover.py | 87 -- 26 files changed, 2501 insertions(+), 1280 deletions(-) create mode 100644 lineless_table_rec/table_structure_lore.py delete mode 100644 lineless_table_rec/utils.py create mode 100644 lineless_table_rec/utils/__init__.py create mode 100644 lineless_table_rec/utils/download_model.py create mode 100644 lineless_table_rec/utils/logger.py create mode 100644 lineless_table_rec/utils/utils.py create mode 100644 lineless_table_rec/utils/utils_table_lore_rec.py rename lineless_table_rec/{ => utils}/utils_table_recover.py (100%) rename wired_table_rec/{table_line_rec.py => table_structure_cycle_center_net.py} (93%) rename wired_table_rec/{table_line_rec_plus.py => table_structure_unet.py} (96%) delete mode 100644 wired_table_rec/utils.py create mode 100644 wired_table_rec/utils/__init__.py create mode 100644 wired_table_rec/utils/download_model.py create mode 100644 wired_table_rec/utils/logger.py create mode 100644 wired_table_rec/utils/utils.py rename wired_table_rec/{ => utils}/utils_table_line_rec.py (100%) rename wired_table_rec/{ => utils}/utils_table_recover.py (87%) diff --git a/demo_lineless.py b/demo_lineless.py index 06725ee..cb8ddef 100644 --- a/demo_lineless.py +++ b/demo_lineless.py @@ -3,30 +3,45 @@ # @Contact: liekkaskono@163.com from pathlib import Path +from rapidocr_onnxruntime import RapidOCR + from lineless_table_rec import LinelessTableRecognition -from lineless_table_rec.utils_table_recover import ( - format_html, - plot_rec_box, - plot_rec_box_with_logic_info, -) +from lineless_table_rec.main import RapidTableInput +from lineless_table_rec.utils.utils import VisTable output_dir = Path("outputs") output_dir.mkdir(parents=True, exist_ok=True) +input_args = RapidTableInput() +table_engine = LinelessTableRecognition(input_args) +ocr_engine = RapidOCR() +viser = VisTable() + +if __name__ == "__main__": + img_path = "tests/test_files/lineless_table_recognition.jpg" + + ocr_result, _ = ocr_engine(img_path) + boxes, txts, scores = list(zip(*ocr_result)) -img_path = "tests/test_files/lineless_table_recognition.jpg" -table_rec = LinelessTableRecognition() + # Table Rec + table_results = table_engine(img_path) + table_html_str, table_cell_bboxes = ( + table_results.pred_html, + 
table_results.cell_bboxes, + ) -html, elasp, polygons, logic_points, ocr_res = table_rec(img_path) -print(f"cost: {elasp:.5f}") + # Save + save_dir = Path("outputs") + save_dir.mkdir(parents=True, exist_ok=True) -complete_html = format_html(html) + save_html_path = f"outputs/{Path(img_path).stem}.html" + save_drawed_path = f"outputs/{Path(img_path).stem}_table_vis{Path(img_path).suffix}" + save_logic_path = ( + f"outputs/{Path(img_path).stem}_table_vis_logic{Path(img_path).suffix}" + ) -save_table_path = output_dir / "table.html" -with open(save_table_path, "w", encoding="utf-8") as file: - file.write(complete_html) + # Visualize table rec result + vis_imged = viser( + img_path, table_results, save_html_path, save_drawed_path, save_logic_path + ) -plot_rec_box_with_logic_info( - img_path, f"{output_dir}/table_rec_box.jpg", logic_points, polygons -) -plot_rec_box(img_path, f"{output_dir}/ocr_box.jpg", ocr_res) -print(f"The results has been saved under {output_dir}") + print(f"The results has been saved under {output_dir}") diff --git a/demo_wired.py b/demo_wired.py index eb99145..385ce39 100644 --- a/demo_wired.py +++ b/demo_wired.py @@ -3,32 +3,44 @@ # @Contact: liekkaskono@163.com from pathlib import Path +from rapidocr_onnxruntime import RapidOCR + from wired_table_rec import WiredTableRecognition -from wired_table_rec.utils_table_recover import ( - format_html, - plot_rec_box, - plot_rec_box_with_logic_info, -) +from wired_table_rec.main import RapidTableInput, ModelType +from wired_table_rec.utils.utils import VisTable output_dir = Path("outputs") output_dir.mkdir(parents=True, exist_ok=True) - -table_rec = WiredTableRecognition() - -img_path = "tests/test_files/wired/table1.png" -html, elasp, polygons, logic_points, ocr_res = table_rec(img_path) - -print(f"cost: {elasp:.5f}") - -complete_html = format_html(html) - -save_table_path = output_dir / "table.html" -with open(save_table_path, "w", encoding="utf-8") as file: - file.write(complete_html) - -plot_rec_box_with_logic_info( - img_path, f"{output_dir}/table_rec_box.jpg", logic_points, polygons -) -plot_rec_box(img_path, f"{output_dir}/ocr_box.jpg", ocr_res) - -print(f"The results has been saved under {output_dir}") +input_args = RapidTableInput(model_type=ModelType.CYCLE_CENTER_NET.value) +table_engine = WiredTableRecognition(input_args) +ocr_engine = RapidOCR() +viser = VisTable() +if __name__ == "__main__": + img_path = "tests/test_files/wired/bad_case_1.png" + + ocr_result, _ = ocr_engine(img_path) + boxes, txts, scores = list(zip(*ocr_result)) + + # Table Rec + table_results = table_engine(img_path) + table_html_str, table_cell_bboxes = ( + table_results.pred_html, + table_results.cell_bboxes, + ) + + # Save + save_dir = Path("outputs") + save_dir.mkdir(parents=True, exist_ok=True) + + save_html_path = f"outputs/{Path(img_path).stem}.html" + save_drawed_path = f"outputs/{Path(img_path).stem}_table_vis{Path(img_path).suffix}" + save_logic_path = ( + f"outputs/{Path(img_path).stem}_table_vis_logic{Path(img_path).suffix}" + ) + + # Visualize table rec result + vis_imged = viser( + img_path, table_results, save_html_path, save_drawed_path, save_logic_path + ) + + print(f"The results has been saved under {output_dir}") diff --git a/lineless_table_rec/main.py b/lineless_table_rec/main.py index 01d2007..c7514c9 100644 --- a/lineless_table_rec/main.py +++ b/lineless_table_rec/main.py @@ -1,19 +1,22 @@ # -*- encoding: utf-8 -*- # @Author: SWHL # @Contact: liekkaskono@163.com +import importlib import logging import time import traceback 
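Note: the hunks below replace the old tuple-returning API with rapidTable-style dataclasses (RapidTableInput in, RapidTableOutput out). A minimal sketch of driving the new entry point, assuming the LORE models can be fetched from ModelScope on first use and with "table.jpg" as a placeholder input path:

    from lineless_table_rec import LinelessTableRecognition
    from lineless_table_rec.main import ModelType, RapidTableInput

    engine = LinelessTableRecognition(RapidTableInput(model_type=ModelType.LORE.value))
    result = engine("table.jpg")    # returns a RapidTableOutput dataclass
    print(result.pred_html)         # reconstructed HTML table
    print(result.cell_bboxes)       # (N, 8) cell polygons
    print(result.logic_points)      # (N, 4): [row_start, row_end, col_start, col_end]
    print(f"elapse: {result.elapse:.3f}s")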
+from dataclasses import dataclass, asdict +from enum import Enum from pathlib import Path -from typing import Any, Dict, List, Tuple, Union, Optional +from typing import Dict, List, Union, Optional import cv2 import numpy as np -from rapidocr_onnxruntime import RapidOCR -from .process import DetProcess, get_affine_transform_upper_left -from .utils import InputType, LoadImage, OrtInferSession -from .utils_table_recover import ( +from .table_structure_lore import TSRLore +from .utils.download_model import DownloadModel +from .utils.utils import InputType, LoadImage +from lineless_table_rec.utils.utils_table_recover import ( box_4_2_poly_to_box_4_1, filter_duplicated_box, gather_ocr_list_by_row, @@ -23,57 +26,76 @@ sorted_ocr_boxes, ) -cur_dir = Path(__file__).resolve().parent -detect_model_path = cur_dir / "models" / "lore_detect.onnx" -process_model_path = cur_dir / "models" / "lore_process.onnx" +class ModelType(Enum): + LORE = "lore" -class LinelessTableRecognition: - def __init__( - self, - detect_model_path: Union[str, Path] = detect_model_path, - process_model_path: Union[str, Path] = process_model_path, - ): - self.mean = np.array([0.408, 0.447, 0.470], dtype=np.float32).reshape(1, 1, 3) - self.std = np.array([0.289, 0.274, 0.278], dtype=np.float32).reshape(1, 1, 3) - self.inp_h = 768 - self.inp_w = 768 +ROOT_URL = "https://www.modelscope.cn/models/RapidAI/RapidTable/resolve/master/" +KEY_TO_MODEL_URL = { + ModelType.LORE.value: { + "lore_detect": f"{ROOT_URL}/lore/detect.onnx", + "lore_process": f"{ROOT_URL}/lore/process.onnx", + }, +} + + +@dataclass +class RapidTableInput: + model_type: Optional[str] = ModelType.LORE.value + model_path: Union[str, Path, None, Dict[str, str]] = None + use_cuda: bool = False + device: str = "cpu" + - self.det_session = OrtInferSession(detect_model_path) - self.process_session = OrtInferSession(process_model_path) +@dataclass +class RapidTableOutput: + pred_html: Optional[str] = None + cell_bboxes: Optional[np.ndarray] = None + logic_points: Optional[np.ndarray] = None + elapse: Optional[float] = None + +class LinelessTableRecognition: + def __init__(self, config: RapidTableInput): + self.model_type = config.model_type + if self.model_type not in KEY_TO_MODEL_URL: + model_list = ",".join(KEY_TO_MODEL_URL) + raise ValueError( + f"{self.model_type} is not supported. The currently supported models are {model_list}." 
+ ) + + config.model_path = self.get_model_path(config.model_type, config.model_path) + self.table_structure = TSRLore(asdict(config)) self.load_img = LoadImage() - self.det_process = DetProcess() - self.ocr = RapidOCR() + try: + self.ocr = importlib.import_module("rapidocr_onnxruntime").RapidOCR() + except ModuleNotFoundError: + self.ocr = None def __call__( self, content: InputType, ocr_result: Optional[List[Union[List[List[float]], str, str]]] = None, - **kwargs - ): - ss = time.perf_counter() + **kwargs, + ) -> RapidTableOutput: + s = time.perf_counter() rec_again = True need_ocr = True if kwargs: rec_again = kwargs.get("rec_again", True) - need_ocr = kwargs.get("need_ocr", True) img = self.load_img(content) - input_info = self.preprocess(img) try: - polygons, slct_logi = self.infer(input_info) - logi_points = self.filter_logi_points(slct_logi) + polygons, logi_points = self.table_structure(img) if not need_ocr: sorted_polygons, idx_list = sorted_ocr_boxes( [box_4_2_poly_to_box_4_1(box) for box in polygons] ) - return ( + return RapidTableOutput( "", - time.perf_counter() - ss, sorted_polygons, logi_points[idx_list], - [], + time.perf_counter() - s, ) if ocr_result is None and need_ocr: @@ -103,32 +125,19 @@ def __call__( i: [ocr_box_and_text[1] for ocr_box_and_text in t_box_ocr["t_ocr_res"]] for i, t_box_ocr in enumerate(t_rec_ocr_list) } - table_str = plot_html_table(logi_points, cell_box_det_map) + pred_html = plot_html_table(logi_points, cell_box_det_map) # 输出可视化排序,用于验证结果,生产版本可以去掉 _, idx_list = sorted_ocr_boxes( [t_box_ocr["t_box"] for t_box_ocr in t_rec_ocr_list] ) - t_rec_ocr_list = [t_rec_ocr_list[i] for i in idx_list] - sorted_polygons = [t_box_ocr["t_box"] for t_box_ocr in t_rec_ocr_list] - sorted_logi_points = [ - t_box_ocr["t_logic_box"] for t_box_ocr in t_rec_ocr_list - ] - ocr_boxes_res = [ - box_4_2_poly_to_box_4_1(ori_ocr[0]) for ori_ocr in ocr_result - ] - sorted_ocr_boxes_res, _ = sorted_ocr_boxes(ocr_boxes_res) - table_elapse = time.perf_counter() - ss - return ( - table_str, - table_elapse, - sorted_polygons, - sorted_logi_points, - sorted_ocr_boxes_res, - ) + polygons = polygons.reshape(-1, 8) + logi_points = np.array(logi_points) + elapse = time.perf_counter() - s except Exception: logging.warning(traceback.format_exc()) - return "", 0.0, None, None, None + return RapidTableOutput("", None, None, 0.0) + return RapidTableOutput(pred_html, polygons, logi_points, elapse) def transform_res( self, @@ -159,48 +168,27 @@ def transform_res( res.append(dict_res) return res - def preprocess(self, img: np.ndarray) -> Dict[str, Any]: - height, width = img.shape[:2] - resized_image = cv2.resize(img, (width, height)) - - c = np.array([0, 0], dtype=np.float32) - s = max(height, width) * 1.0 - trans_input = get_affine_transform_upper_left(c, s, [self.inp_w, self.inp_h]) - - inp_image = cv2.warpAffine( - resized_image, trans_input, (self.inp_w, self.inp_h), flags=cv2.INTER_LINEAR - ) - inp_image = ((inp_image / 255.0 - self.mean) / self.std).astype(np.float32) - - images = inp_image.transpose(2, 0, 1).reshape(1, 3, self.inp_h, self.inp_w) - meta = { - "c": c, - "s": s, - "out_height": self.inp_h // 4, - "out_width": self.inp_w // 4, - } - return {"img": images, "meta": meta} + @staticmethod + def get_model_path( + model_type: str, model_path: Union[str, Path, None] + ) -> Union[str, Dict[str, str]]: + if model_path is not None: + return model_path - def infer(self, input_content: Dict[str, Any]) -> Tuple[np.ndarray, np.ndarray]: - hm, st, wh, ax, cr, reg = 
self.det_session([input_content["img"]]) - output = { - "hm": hm, - "st": st, - "wh": wh, - "ax": ax, - "cr": cr, - "reg": reg, - } - slct_logi_feat, slct_dets_feat, slct_output_dets = self.det_process( - output, input_content["meta"] - ) + model_url = KEY_TO_MODEL_URL.get(model_type, None) + if isinstance(model_url, str): + model_path = DownloadModel.download(model_url) + return model_path - slct_output_dets = slct_output_dets.reshape(-1, 4, 2) + if isinstance(model_url, dict): + model_paths = {} + for k, url in model_url.items(): + model_paths[k] = DownloadModel.download( + url, save_model_name=f"{model_type}_{Path(url).name}" + ) + return model_paths - _, slct_logi = self.process_session( - [slct_logi_feat, slct_dets_feat.astype(np.int64)] - ) - return slct_output_dets, slct_logi + raise ValueError(f"Model URL: {type(model_url)} is not between str and dict.") def sort_and_gather_ocr_res(self, res): for i, dict_res in enumerate(res): @@ -254,23 +242,6 @@ def handle_overlap_row_col(self, res): res = [res[i] for i in range(len(res)) if i not in deleted_idx] return res, grid - @staticmethod - def filter_logi_points(slct_logi: np.ndarray) -> List[np.ndarray]: - for logic_points in slct_logi[0]: - # 修正坐标接近导致的r_e > r_s 或 c_e > c_s - if abs(logic_points[0] - logic_points[1]) < 0.2: - row = (logic_points[0] + logic_points[1]) / 2 - logic_points[0] = row - logic_points[1] = row - if abs(logic_points[2] - logic_points[3]) < 0.2: - col = (logic_points[2] + logic_points[3]) / 2 - logic_points[2] = col - logic_points[3] = col - logi_floor = np.floor(slct_logi) - dev = slct_logi - logi_floor - slct_logi = np.where(dev > 0.5, logi_floor + 1, logi_floor) - return slct_logi[0].astype(np.int32) - def re_rec( self, img: np.ndarray, diff --git a/lineless_table_rec/process.py b/lineless_table_rec/process.py index 7043d3c..ded8191 100644 --- a/lineless_table_rec/process.py +++ b/lineless_table_rec/process.py @@ -1,408 +1,408 @@ -# ------------------------------------------------------------------------------ -# Part of implementation is adopted from CenterNet, -# made publicly available under the MIT License at https://github.com/xingyizhou/CenterNet.git -# ------------------------------------------------------------------------------ -import warnings -from typing import Dict, List, Tuple, Union - -import cv2 -import numpy as np - -# suppress warnings -warnings.filterwarnings("ignore") - - -class DetProcess: - def __init__(self, K: int = 3000, num_classes: int = 2, scale: float = 1.0): - self.K = K - self.num_classes = num_classes - self.scale = scale - self.max_per_image = 3000 - - def __call__( - self, det_out: Dict[str, np.ndarray], meta: Dict[str, Union[int, np.ndarray]] - ): - hm = self.sigmoid(det_out["hm"]) - dets, keep, logi, cr = ctdet_4ps_decode( - hm[:, 0:1, :, :], - det_out["wh"], - det_out["ax"], - det_out["cr"], - reg=det_out["reg"], - K=self.K, - ) - - raw_dets = dets - dets = dets.reshape(1, -1, dets.shape[2]) - dets = ctdet_4ps_post_process_upper_left( - dets.copy(), - [meta["c"]], - [meta["s"]], - meta["out_height"], - meta["out_width"], - 2, - ) - for j in range(1, self.num_classes + 1): - dets[0][j] = np.array(dets[0][j], dtype=np.float32).reshape(-1, 9) - dets[0][j][:, :8] /= self.scale - dets = dets[0] - detections = [dets] - - logi += cr - results = self.merge_outputs(detections) - slct_logi_feat, slct_dets_feat = self.filter(results, logi, raw_dets[:, :, :8]) - slct_output_dets = results[1][: slct_logi_feat.shape[1], :8] - return slct_logi_feat, slct_dets_feat, slct_output_dets - - 
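Note: the get_model_path helper added above resolves a model type to its ONNX files via KEY_TO_MODEL_URL, downloading each on first use. A minimal sketch of the equivalent manual call, assuming network access to the ModelScope URLs; the save name mirrors the f"{model_type}_{Path(url).name}" convention used above:

    from pathlib import Path

    from lineless_table_rec.main import KEY_TO_MODEL_URL, ModelType
    from lineless_table_rec.utils.download_model import DownloadModel

    urls = KEY_TO_MODEL_URL[ModelType.LORE.value]
    model_paths = {
        key: DownloadModel.download(url, save_model_name=f"lore_{Path(url).name}")
        for key, url in urls.items()
    }
    # -> {"lore_detect": ".../models/lore_detect.onnx",
    #     "lore_process": ".../models/lore_process.onnx"}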
@staticmethod - def sigmoid(data: np.ndarray) -> np.ndarray: - return 1 / (1 + np.exp(-data)) - - def merge_outputs(self, detections: Dict[int, np.ndarray]) -> Dict[int, np.ndarray]: - # thresh_conf, thresh_min, thresh_max = 0.1, 0.5, 0.7 - results = {} - for j in range(1, self.num_classes + 1): - results[j] = np.concatenate( - [detection[j] for detection in detections], axis=0 - ).astype(np.float32) - - scores = np.hstack([results[j][:, 8] for j in range(1, self.num_classes + 1)]) - if len(scores) > self.max_per_image: - kth = len(scores) - self.max_per_image - thresh = np.partition(scores, kth)[kth] - for j in range(1, self.num_classes + 1): - keep_inds = results[j][:, 8] >= thresh - results[j] = results[j][keep_inds] - return results - - @staticmethod - def filter( - results: Dict[int, np.ndarray], logi: np.ndarray, ps: np.ndarray - ) -> Tuple[np.ndarray, np.ndarray]: - # this function select boxes - batch_size, feat_dim = logi.shape[0], logi.shape[2] - num_valid = sum(results[1][:, 8] >= 0.15) - - slct_logi = np.zeros((batch_size, num_valid, feat_dim), dtype=np.float32) - slct_dets = np.zeros((batch_size, num_valid, 8), dtype=np.int32) - for i in range(batch_size): - for j in range(num_valid): - slct_logi[i, j, :] = logi[i, j, :] - slct_dets[i, j, :] = ps[i, j, :] - - return slct_logi, slct_dets - - -def ctdet_4ps_decode( - heat: np.ndarray, - wh: np.ndarray, - ax: np.ndarray, - cr: np.ndarray, - reg: np.ndarray = None, - cat_spec_wh: bool = False, - K: int = 100, -): - batch, cat, _, width = heat.shape - heat, keep = _nms(heat) - scores, inds, clses, ys, xs = _topk(heat, K=K) - - if reg is not None: - reg = _tranpose_and_gather_feat(reg, inds) - reg = reg.reshape(batch, K, 2) - xs = xs.reshape(batch, K, 1) + reg[:, :, 0:1] - ys = ys.reshape(batch, K, 1) + reg[:, :, 1:2] - else: - xs = xs.reshape(batch, K, 1) + 0.5 - ys = ys.reshape(batch, K, 1) + 0.5 - - wh = _tranpose_and_gather_feat(wh, inds) - ax = _tranpose_and_gather_feat(ax, inds) - - if cat_spec_wh: - wh = wh.reshape(batch, K, cat, 8) - clses_ind = clses.reshape(batch, K, 1, 1).expand(batch, K, 1, 8) - wh = wh.gather(2, clses_ind).reshape(batch, K, 8) - else: - wh = wh.reshape(batch, K, 8) - - clses = clses.reshape(batch, K, 1) - scores = scores.reshape(batch, K, 1) - - bboxes_vec = [ - xs - wh[..., 0:1], - ys - wh[..., 1:2], - xs - wh[..., 2:3], - ys - wh[..., 3:4], - xs - wh[..., 4:5], - ys - wh[..., 5:6], - xs - wh[..., 6:7], - ys - wh[..., 7:8], - ] - bboxes = np.concatenate(bboxes_vec, axis=2) - - cc_match = np.concatenate( - [ - (xs - wh[..., 0:1]) + width * np.round(ys - wh[..., 1:2]), - (xs - wh[..., 2:3]) + width * np.round(ys - wh[..., 3:4]), - (xs - wh[..., 4:5]) + width * np.round(ys - wh[..., 5:6]), - (xs - wh[..., 6:7]) + width * np.round(ys - wh[..., 7:8]), - ], - axis=2, - ) - cc_match = np.round(cc_match).astype(np.int64) - cr_feat = _get_4ps_feat(cc_match, cr) - cr_feat = cr_feat.sum(axis=3) - - detections = np.concatenate([bboxes, scores, clses], axis=2) - return detections, keep, ax, cr_feat - - -def _nms(heat: np.ndarray, kernel: int = 3) -> Tuple[np.ndarray, np.ndarray]: - pad = (kernel - 1) // 2 - hmax = max_pool(heat, kernel_size=kernel, stride=1, padding=pad) - keep = hmax == heat - return heat * keep, keep - - -def max_pool( - img: np.ndarray, kernel_size: int, stride: int, padding: int -) -> np.ndarray: - h, w = img.shape[2:] - img = np.pad( - img, - ((0, 0), (0, 0), (padding, padding), (padding, padding)), - "constant", - constant_values=0, - ) - - res_h = ((h + 2 - kernel_size) // stride) + 1 - 
res_w = ((w + 2 - kernel_size) // stride) + 1 - res = np.zeros((img.shape[0], img.shape[1], res_h, res_w)) - for i in range(res_h): - for j in range(res_w): - temp = img[ - :, - :, - i * stride : i * stride + kernel_size, - j * stride : j * stride + kernel_size, - ] - res[:, :, i, j] = temp.max() - return res - - -def _topk( - scores: np.ndarray, K: int = 40 -) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: - batch, cat, height, width = scores.shape - - topk_scores, topk_inds = find_topk(scores.reshape(batch, cat, -1), K) - - topk_inds = topk_inds % (height * width) - topk_ys = topk_inds / width - topk_xs = np.float32(np.int32(topk_inds % width)) - - topk_score, topk_ind = find_topk(topk_scores.reshape(batch, -1), K) - topk_clses = np.int32(topk_ind / K) - topk_inds = _gather_feat(topk_inds.reshape(batch, -1, 1), topk_ind).reshape( - batch, K - ) - topk_ys = _gather_feat(topk_ys.reshape(batch, -1, 1), topk_ind).reshape(batch, K) - topk_xs = _gather_feat(topk_xs.reshape(batch, -1, 1), topk_ind).reshape(batch, K) - - return topk_score, topk_inds, topk_clses, topk_ys, topk_xs - - -def find_topk( - a: np.ndarray, k: int, axis: int = -1, largest: bool = True, sorted: bool = True -) -> Tuple[np.ndarray, np.ndarray]: - if axis is None: - axis_size = a.size - else: - axis_size = a.shape[axis] - assert 1 <= k <= axis_size - - a = np.asanyarray(a) - if largest: - index_array = np.argpartition(a, axis_size - k, axis=axis) - topk_indices = np.take(index_array, -np.arange(k) - 1, axis=axis) - else: - index_array = np.argpartition(a, k - 1, axis=axis) - topk_indices = np.take(index_array, np.arange(k), axis=axis) - - topk_values = np.take_along_axis(a, topk_indices, axis=axis) - if sorted: - sorted_indices_in_topk = np.argsort(topk_values, axis=axis) - if largest: - sorted_indices_in_topk = np.flip(sorted_indices_in_topk, axis=axis) - - sorted_topk_values = np.take_along_axis( - topk_values, sorted_indices_in_topk, axis=axis - ) - sorted_topk_indices = np.take_along_axis( - topk_indices, sorted_indices_in_topk, axis=axis - ) - return sorted_topk_values, sorted_topk_indices - return topk_values, topk_indices - - -def _gather_feat(feat: np.ndarray, ind: np.ndarray) -> np.ndarray: - dim = feat.shape[2] - ind = np.broadcast_to(ind[:, :, None], (ind.shape[0], ind.shape[1], dim)) - feat = _gather(feat, 1, ind) - return feat - - -def _gather(data: np.ndarray, dim: int, index: np.ndarray) -> np.ndarray: - """ - Gathers values along an axis specified by dim. 
- For a 3-D tensor the output is specified by: - out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0 - out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1 - out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2 - - :param dim: The axis along which to index - :param index: A tensor of indices of elements to gather - :return: tensor of gathered values - """ - idx_xsection_shape = index.shape[:dim] + index.shape[dim + 1 :] - data_xsection_shape = data.shape[:dim] + data.shape[dim + 1 :] - if idx_xsection_shape != data_xsection_shape: - raise ValueError( - "Except for dimension " - + str(dim) - + ", all dimensions of index and data should be the same size" - ) - - if index.dtype != np.int64: - raise TypeError("The values of index must be integers") - - data_swaped = np.swapaxes(data, 0, dim) - index_swaped = np.swapaxes(index, 0, dim) - gathered = np.take_along_axis(data_swaped, index_swaped, axis=0) - return np.swapaxes(gathered, 0, dim) - - -def _tranpose_and_gather_feat(feat: np.ndarray, ind: np.ndarray) -> np.ndarray: - feat = np.ascontiguousarray(np.transpose(feat, [0, 2, 3, 1])) - feat = feat.reshape(feat.shape[0], -1, feat.shape[3]) - feat = _gather_feat(feat, ind) - return feat - - -def _get_4ps_feat(cc_match: np.ndarray, output: np.ndarray) -> np.ndarray: - if isinstance(output, dict): - feat = output["cr"] - else: - feat = output - - feat = np.ascontiguousarray(feat.transpose(0, 2, 3, 1)) - feat = feat.reshape(feat.shape[0], -1, feat.shape[3]) - feat = feat[..., None] - feat = np.concatenate([feat] * 4, axis=-1) - - dim = feat.shape[2] - cc_match = cc_match[..., None, :] - cc_match = np.concatenate([cc_match] * dim, axis=2) - if not (isinstance(output, dict)): - cc_match = np.where( - cc_match < feat.shape[1], - cc_match, - (feat.shape[0] - 1) * np.ones(cc_match.shape).astype(np.int64), - ) - - cc_match = np.where( - cc_match >= 0, cc_match, np.zeros(cc_match.shape).astype(np.int64) - ) - feat = np.take_along_axis(feat, cc_match, axis=1) - return feat - - -def ctdet_4ps_post_process_upper_left( - dets: np.ndarray, - c: List[np.ndarray], - s: List[float], - h: int, - w: int, - num_classes: int, -) -> np.ndarray: - # dets: batch x max_dets x dim - # return 1-based class det dict - ret = [] - for i in range(dets.shape[0]): - top_preds = {} - dets[i, :, 0:2] = transform_preds_upper_left( - dets[i, :, 0:2], c[i], s[i], (w, h) - ) - dets[i, :, 2:4] = transform_preds_upper_left( - dets[i, :, 2:4], c[i], s[i], (w, h) - ) - dets[i, :, 4:6] = transform_preds_upper_left( - dets[i, :, 4:6], c[i], s[i], (w, h) - ) - dets[i, :, 6:8] = transform_preds_upper_left( - dets[i, :, 6:8], c[i], s[i], (w, h) - ) - classes = dets[i, :, -1] - for j in range(num_classes): - inds = classes == j - tmp_top_pred = [ - dets[i, inds, :8].astype(np.float32), - dets[i, inds, 8:9].astype(np.float32), - ] - top_preds[j + 1] = np.concatenate(tmp_top_pred, axis=1).tolist() - ret.append(top_preds) - return ret - - -def transform_preds_upper_left( - coords: np.ndarray, - center: np.ndarray, - scale: float, - output_size: Tuple[int, int], -) -> np.ndarray: - target_coords = np.zeros(coords.shape) - - trans = get_affine_transform_upper_left(center, scale, output_size, inv=1) - for p in range(coords.shape[0]): - target_coords[p, 0:2] = affine_transform(coords[p, 0:2], trans) - return target_coords - - -def get_affine_transform_upper_left( - center: np.ndarray, - scale: float, - output_size: List[Tuple[int, int]], - inv: int = 0, -) -> np.ndarray: - if not isinstance(scale, np.ndarray) and not isinstance(scale, 
list): - scale = np.array([scale, scale], dtype=np.float32) - - src = np.zeros((3, 2), dtype=np.float32) - dst = np.zeros((3, 2), dtype=np.float32) - src[0, :] = center - dst[0, :] = [0, 0] - if center[0] < center[1]: - src[1, :] = [scale[0], center[1]] - dst[1, :] = [output_size[0], 0] - else: - src[1, :] = [center[0], scale[0]] - dst[1, :] = [0, output_size[0]] - src[2:, :] = get_3rd_point(src[0, :], src[1, :]) - dst[2:, :] = get_3rd_point(dst[0, :], dst[1, :]) - - if inv: - trans = cv2.getAffineTransform(np.float32(dst), np.float32(src)) - else: - trans = cv2.getAffineTransform(np.float32(src), np.float32(dst)) - return trans - - -def get_3rd_point(a: np.ndarray, b: np.ndarray) -> np.ndarray: - direct = a - b - return b + np.array([-direct[1], direct[0]], dtype=np.float32) - - -def affine_transform(pt: np.ndarray, t: np.ndarray) -> np.ndarray: - new_pt = np.array([pt[0], pt[1], 1.0], dtype=np.float32).T - new_pt = np.dot(t, new_pt) - return new_pt[:2] +# # ------------------------------------------------------------------------------ +# # Part of implementation is adopted from CenterNet, +# # made publicly available under the MIT License at https://github.com/xingyizhou/CenterNet.git +# # ------------------------------------------------------------------------------ +# import warnings +# from typing import Dict, List, Tuple, Union +# +# import cv2 +# import numpy as np +# +# # suppress warnings +# warnings.filterwarnings("ignore") +# +# +# class DetProcess: +# def __init__(self, K: int = 3000, num_classes: int = 2, scale: float = 1.0): +# self.K = K +# self.num_classes = num_classes +# self.scale = scale +# self.max_per_image = 3000 +# +# def __call__( +# self, det_out: Dict[str, np.ndarray], meta: Dict[str, Union[int, np.ndarray]] +# ): +# hm = self.sigmoid(det_out["hm"]) +# dets, keep, logi, cr = ctdet_4ps_decode( +# hm[:, 0:1, :, :], +# det_out["wh"], +# det_out["ax"], +# det_out["cr"], +# reg=det_out["reg"], +# K=self.K, +# ) +# +# raw_dets = dets +# dets = dets.reshape(1, -1, dets.shape[2]) +# dets = ctdet_4ps_post_process_upper_left( +# dets.copy(), +# [meta["c"]], +# [meta["s"]], +# meta["out_height"], +# meta["out_width"], +# 2, +# ) +# for j in range(1, self.num_classes + 1): +# dets[0][j] = np.array(dets[0][j], dtype=np.float32).reshape(-1, 9) +# dets[0][j][:, :8] /= self.scale +# dets = dets[0] +# detections = [dets] +# +# logi += cr +# results = self.merge_outputs(detections) +# slct_logi_feat, slct_dets_feat = self.filter(results, logi, raw_dets[:, :, :8]) +# slct_output_dets = results[1][: slct_logi_feat.shape[1], :8] +# return slct_logi_feat, slct_dets_feat, slct_output_dets +# +# @staticmethod +# def sigmoid(data: np.ndarray) -> np.ndarray: +# return 1 / (1 + np.exp(-data)) +# +# def merge_outputs(self, detections: Dict[int, np.ndarray]) -> Dict[int, np.ndarray]: +# # thresh_conf, thresh_min, thresh_max = 0.1, 0.5, 0.7 +# results = {} +# for j in range(1, self.num_classes + 1): +# results[j] = np.concatenate( +# [detection[j] for detection in detections], axis=0 +# ).astype(np.float32) +# +# scores = np.hstack([results[j][:, 8] for j in range(1, self.num_classes + 1)]) +# if len(scores) > self.max_per_image: +# kth = len(scores) - self.max_per_image +# thresh = np.partition(scores, kth)[kth] +# for j in range(1, self.num_classes + 1): +# keep_inds = results[j][:, 8] >= thresh +# results[j] = results[j][keep_inds] +# return results +# +# @staticmethod +# def filter( +# results: Dict[int, np.ndarray], logi: np.ndarray, ps: np.ndarray +# ) -> Tuple[np.ndarray, np.ndarray]: 
+# # this function select boxes +# batch_size, feat_dim = logi.shape[0], logi.shape[2] +# num_valid = sum(results[1][:, 8] >= 0.15) +# +# slct_logi = np.zeros((batch_size, num_valid, feat_dim), dtype=np.float32) +# slct_dets = np.zeros((batch_size, num_valid, 8), dtype=np.int32) +# for i in range(batch_size): +# for j in range(num_valid): +# slct_logi[i, j, :] = logi[i, j, :] +# slct_dets[i, j, :] = ps[i, j, :] +# +# return slct_logi, slct_dets +# +# +# def ctdet_4ps_decode( +# heat: np.ndarray, +# wh: np.ndarray, +# ax: np.ndarray, +# cr: np.ndarray, +# reg: np.ndarray = None, +# cat_spec_wh: bool = False, +# K: int = 100, +# ): +# batch, cat, _, width = heat.shape +# heat, keep = _nms(heat) +# scores, inds, clses, ys, xs = _topk(heat, K=K) +# +# if reg is not None: +# reg = _tranpose_and_gather_feat(reg, inds) +# reg = reg.reshape(batch, K, 2) +# xs = xs.reshape(batch, K, 1) + reg[:, :, 0:1] +# ys = ys.reshape(batch, K, 1) + reg[:, :, 1:2] +# else: +# xs = xs.reshape(batch, K, 1) + 0.5 +# ys = ys.reshape(batch, K, 1) + 0.5 +# +# wh = _tranpose_and_gather_feat(wh, inds) +# ax = _tranpose_and_gather_feat(ax, inds) +# +# if cat_spec_wh: +# wh = wh.reshape(batch, K, cat, 8) +# clses_ind = clses.reshape(batch, K, 1, 1).expand(batch, K, 1, 8) +# wh = wh.gather(2, clses_ind).reshape(batch, K, 8) +# else: +# wh = wh.reshape(batch, K, 8) +# +# clses = clses.reshape(batch, K, 1) +# scores = scores.reshape(batch, K, 1) +# +# bboxes_vec = [ +# xs - wh[..., 0:1], +# ys - wh[..., 1:2], +# xs - wh[..., 2:3], +# ys - wh[..., 3:4], +# xs - wh[..., 4:5], +# ys - wh[..., 5:6], +# xs - wh[..., 6:7], +# ys - wh[..., 7:8], +# ] +# bboxes = np.concatenate(bboxes_vec, axis=2) +# +# cc_match = np.concatenate( +# [ +# (xs - wh[..., 0:1]) + width * np.round(ys - wh[..., 1:2]), +# (xs - wh[..., 2:3]) + width * np.round(ys - wh[..., 3:4]), +# (xs - wh[..., 4:5]) + width * np.round(ys - wh[..., 5:6]), +# (xs - wh[..., 6:7]) + width * np.round(ys - wh[..., 7:8]), +# ], +# axis=2, +# ) +# cc_match = np.round(cc_match).astype(np.int64) +# cr_feat = _get_4ps_feat(cc_match, cr) +# cr_feat = cr_feat.sum(axis=3) +# +# detections = np.concatenate([bboxes, scores, clses], axis=2) +# return detections, keep, ax, cr_feat +# +# +# def _nms(heat: np.ndarray, kernel: int = 3) -> Tuple[np.ndarray, np.ndarray]: +# pad = (kernel - 1) // 2 +# hmax = max_pool(heat, kernel_size=kernel, stride=1, padding=pad) +# keep = hmax == heat +# return heat * keep, keep +# +# +# def max_pool( +# img: np.ndarray, kernel_size: int, stride: int, padding: int +# ) -> np.ndarray: +# h, w = img.shape[2:] +# img = np.pad( +# img, +# ((0, 0), (0, 0), (padding, padding), (padding, padding)), +# "constant", +# constant_values=0, +# ) +# +# res_h = ((h + 2 - kernel_size) // stride) + 1 +# res_w = ((w + 2 - kernel_size) // stride) + 1 +# res = np.zeros((img.shape[0], img.shape[1], res_h, res_w)) +# for i in range(res_h): +# for j in range(res_w): +# temp = img[ +# :, +# :, +# i * stride : i * stride + kernel_size, +# j * stride : j * stride + kernel_size, +# ] +# res[:, :, i, j] = temp.max() +# return res +# +# +# def _topk( +# scores: np.ndarray, K: int = 40 +# ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: +# batch, cat, height, width = scores.shape +# +# topk_scores, topk_inds = find_topk(scores.reshape(batch, cat, -1), K) +# +# topk_inds = topk_inds % (height * width) +# topk_ys = topk_inds / width +# topk_xs = np.float32(np.int32(topk_inds % width)) +# +# topk_score, topk_ind = find_topk(topk_scores.reshape(batch, -1), K) +# 
topk_clses = np.int32(topk_ind / K) +# topk_inds = _gather_feat(topk_inds.reshape(batch, -1, 1), topk_ind).reshape( +# batch, K +# ) +# topk_ys = _gather_feat(topk_ys.reshape(batch, -1, 1), topk_ind).reshape(batch, K) +# topk_xs = _gather_feat(topk_xs.reshape(batch, -1, 1), topk_ind).reshape(batch, K) +# +# return topk_score, topk_inds, topk_clses, topk_ys, topk_xs +# +# +# def find_topk( +# a: np.ndarray, k: int, axis: int = -1, largest: bool = True, sorted: bool = True +# ) -> Tuple[np.ndarray, np.ndarray]: +# if axis is None: +# axis_size = a.size +# else: +# axis_size = a.shape[axis] +# assert 1 <= k <= axis_size +# +# a = np.asanyarray(a) +# if largest: +# index_array = np.argpartition(a, axis_size - k, axis=axis) +# topk_indices = np.take(index_array, -np.arange(k) - 1, axis=axis) +# else: +# index_array = np.argpartition(a, k - 1, axis=axis) +# topk_indices = np.take(index_array, np.arange(k), axis=axis) +# +# topk_values = np.take_along_axis(a, topk_indices, axis=axis) +# if sorted: +# sorted_indices_in_topk = np.argsort(topk_values, axis=axis) +# if largest: +# sorted_indices_in_topk = np.flip(sorted_indices_in_topk, axis=axis) +# +# sorted_topk_values = np.take_along_axis( +# topk_values, sorted_indices_in_topk, axis=axis +# ) +# sorted_topk_indices = np.take_along_axis( +# topk_indices, sorted_indices_in_topk, axis=axis +# ) +# return sorted_topk_values, sorted_topk_indices +# return topk_values, topk_indices +# +# +# def _gather_feat(feat: np.ndarray, ind: np.ndarray) -> np.ndarray: +# dim = feat.shape[2] +# ind = np.broadcast_to(ind[:, :, None], (ind.shape[0], ind.shape[1], dim)) +# feat = _gather(feat, 1, ind) +# return feat +# +# +# def _gather(data: np.ndarray, dim: int, index: np.ndarray) -> np.ndarray: +# """ +# Gathers values along an axis specified by dim. 
+# For a 3-D tensor the output is specified by: +# out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0 +# out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1 +# out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2 +# +# :param dim: The axis along which to index +# :param index: A tensor of indices of elements to gather +# :return: tensor of gathered values +# """ +# idx_xsection_shape = index.shape[:dim] + index.shape[dim + 1 :] +# data_xsection_shape = data.shape[:dim] + data.shape[dim + 1 :] +# if idx_xsection_shape != data_xsection_shape: +# raise ValueError( +# "Except for dimension " +# + str(dim) +# + ", all dimensions of index and data should be the same size" +# ) +# +# if index.dtype != np.int64: +# raise TypeError("The values of index must be integers") +# +# data_swaped = np.swapaxes(data, 0, dim) +# index_swaped = np.swapaxes(index, 0, dim) +# gathered = np.take_along_axis(data_swaped, index_swaped, axis=0) +# return np.swapaxes(gathered, 0, dim) +# +# +# def _tranpose_and_gather_feat(feat: np.ndarray, ind: np.ndarray) -> np.ndarray: +# feat = np.ascontiguousarray(np.transpose(feat, [0, 2, 3, 1])) +# feat = feat.reshape(feat.shape[0], -1, feat.shape[3]) +# feat = _gather_feat(feat, ind) +# return feat +# +# +# def _get_4ps_feat(cc_match: np.ndarray, output: np.ndarray) -> np.ndarray: +# if isinstance(output, dict): +# feat = output["cr"] +# else: +# feat = output +# +# feat = np.ascontiguousarray(feat.transpose(0, 2, 3, 1)) +# feat = feat.reshape(feat.shape[0], -1, feat.shape[3]) +# feat = feat[..., None] +# feat = np.concatenate([feat] * 4, axis=-1) +# +# dim = feat.shape[2] +# cc_match = cc_match[..., None, :] +# cc_match = np.concatenate([cc_match] * dim, axis=2) +# if not (isinstance(output, dict)): +# cc_match = np.where( +# cc_match < feat.shape[1], +# cc_match, +# (feat.shape[0] - 1) * np.ones(cc_match.shape).astype(np.int64), +# ) +# +# cc_match = np.where( +# cc_match >= 0, cc_match, np.zeros(cc_match.shape).astype(np.int64) +# ) +# feat = np.take_along_axis(feat, cc_match, axis=1) +# return feat +# +# +# def ctdet_4ps_post_process_upper_left( +# dets: np.ndarray, +# c: List[np.ndarray], +# s: List[float], +# h: int, +# w: int, +# num_classes: int, +# ) -> np.ndarray: +# # dets: batch x max_dets x dim +# # return 1-based class det dict +# ret = [] +# for i in range(dets.shape[0]): +# top_preds = {} +# dets[i, :, 0:2] = transform_preds_upper_left( +# dets[i, :, 0:2], c[i], s[i], (w, h) +# ) +# dets[i, :, 2:4] = transform_preds_upper_left( +# dets[i, :, 2:4], c[i], s[i], (w, h) +# ) +# dets[i, :, 4:6] = transform_preds_upper_left( +# dets[i, :, 4:6], c[i], s[i], (w, h) +# ) +# dets[i, :, 6:8] = transform_preds_upper_left( +# dets[i, :, 6:8], c[i], s[i], (w, h) +# ) +# classes = dets[i, :, -1] +# for j in range(num_classes): +# inds = classes == j +# tmp_top_pred = [ +# dets[i, inds, :8].astype(np.float32), +# dets[i, inds, 8:9].astype(np.float32), +# ] +# top_preds[j + 1] = np.concatenate(tmp_top_pred, axis=1).tolist() +# ret.append(top_preds) +# return ret +# +# +# def transform_preds_upper_left( +# coords: np.ndarray, +# center: np.ndarray, +# scale: float, +# output_size: Tuple[int, int], +# ) -> np.ndarray: +# target_coords = np.zeros(coords.shape) +# +# trans = get_affine_transform_upper_left(center, scale, output_size, inv=1) +# for p in range(coords.shape[0]): +# target_coords[p, 0:2] = affine_transform(coords[p, 0:2], trans) +# return target_coords +# +# +# def get_affine_transform_upper_left( +# center: np.ndarray, +# scale: float, +# output_size: 
List[Tuple[int, int]], +# inv: int = 0, +# ) -> np.ndarray: +# if not isinstance(scale, np.ndarray) and not isinstance(scale, list): +# scale = np.array([scale, scale], dtype=np.float32) +# +# src = np.zeros((3, 2), dtype=np.float32) +# dst = np.zeros((3, 2), dtype=np.float32) +# src[0, :] = center +# dst[0, :] = [0, 0] +# if center[0] < center[1]: +# src[1, :] = [scale[0], center[1]] +# dst[1, :] = [output_size[0], 0] +# else: +# src[1, :] = [center[0], scale[0]] +# dst[1, :] = [0, output_size[0]] +# src[2:, :] = get_3rd_point(src[0, :], src[1, :]) +# dst[2:, :] = get_3rd_point(dst[0, :], dst[1, :]) +# +# if inv: +# trans = cv2.getAffineTransform(np.float32(dst), np.float32(src)) +# else: +# trans = cv2.getAffineTransform(np.float32(src), np.float32(dst)) +# return trans +# +# +# def get_3rd_point(a: np.ndarray, b: np.ndarray) -> np.ndarray: +# direct = a - b +# return b + np.array([-direct[1], direct[0]], dtype=np.float32) +# +# +# def affine_transform(pt: np.ndarray, t: np.ndarray) -> np.ndarray: +# new_pt = np.array([pt[0], pt[1], 1.0], dtype=np.float32).T +# new_pt = np.dot(t, new_pt) +# return new_pt[:2] diff --git a/lineless_table_rec/table_structure_lore.py b/lineless_table_rec/table_structure_lore.py new file mode 100644 index 0000000..eb14dd3 --- /dev/null +++ b/lineless_table_rec/table_structure_lore.py @@ -0,0 +1,92 @@ +from copy import deepcopy +from typing import Dict, Any, Tuple, Optional + +import cv2 +import numpy as np + +from .utils.utils import OrtInferSession +from .utils.utils_table_lore_rec import DetProcess, get_affine_transform_upper_left + + +class TSRLore: + def __init__(self, config: Dict): + self.mean = np.array([0.408, 0.447, 0.470], dtype=np.float32).reshape(1, 1, 3) + self.std = np.array([0.289, 0.274, 0.278], dtype=np.float32).reshape(1, 1, 3) + + self.inp_h = 768 + self.inp_w = 768 + + det_config = deepcopy(config) + process_config = deepcopy(config) + det_config["model_path"] = config["model_path"]["lore_detect"] + process_config["model_path"] = config["model_path"]["lore_process"] + self.det_session = OrtInferSession(det_config) + self.process_session = OrtInferSession(process_config) + self.det_process = DetProcess() + + def __call__( + self, img: np.ndarray, **kwargs + ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]: + img_info = self.preprocess(img) + polygons, slct_logi = self.infer(img_info) + logi_points = self.postprocess(slct_logi) + return polygons, logi_points + + def preprocess(self, img: np.ndarray) -> Dict[str, Any]: + height, width = img.shape[:2] + resized_image = cv2.resize(img, (width, height)) + + c = np.array([0, 0], dtype=np.float32) + s = max(height, width) * 1.0 + trans_input = get_affine_transform_upper_left(c, s, [self.inp_w, self.inp_h]) + + inp_image = cv2.warpAffine( + resized_image, trans_input, (self.inp_w, self.inp_h), flags=cv2.INTER_LINEAR + ) + inp_image = ((inp_image / 255.0 - self.mean) / self.std).astype(np.float32) + + images = inp_image.transpose(2, 0, 1).reshape(1, 3, self.inp_h, self.inp_w) + meta = { + "c": c, + "s": s, + "out_height": self.inp_h // 4, + "out_width": self.inp_w // 4, + } + return {"img": images, "meta": meta} + + def infer(self, input_content: Dict[str, Any]) -> Tuple[np.ndarray, np.ndarray]: + hm, st, wh, ax, cr, reg = self.det_session([input_content["img"]]) + output = { + "hm": hm, + "st": st, + "wh": wh, + "ax": ax, + "cr": cr, + "reg": reg, + } + slct_logi_feat, slct_dets_feat, slct_output_dets = self.det_process( + output, input_content["meta"] + ) + + slct_output_dets = 
slct_output_dets.reshape(-1, 4, 2) + + _, slct_logi = self.process_session( + [slct_logi_feat, slct_dets_feat.astype(np.int64)] + ) + return slct_output_dets, slct_logi + + def postprocess(self, slct_logi: np.ndarray) -> np.ndarray: + for logic_points in slct_logi[0]: + # 修正坐标接近导致的r_e > r_s 或 c_e > c_s + if abs(logic_points[0] - logic_points[1]) < 0.2: + row = (logic_points[0] + logic_points[1]) / 2 + logic_points[0] = row + logic_points[1] = row + if abs(logic_points[2] - logic_points[3]) < 0.2: + col = (logic_points[2] + logic_points[3]) / 2 + logic_points[2] = col + logic_points[3] = col + logi_floor = np.floor(slct_logi) + dev = slct_logi - logi_floor + slct_logi = np.where(dev > 0.5, logi_floor + 1, logi_floor) + return slct_logi[0].astype(np.int32) diff --git a/lineless_table_rec/utils.py b/lineless_table_rec/utils.py deleted file mode 100644 index d69f944..0000000 --- a/lineless_table_rec/utils.py +++ /dev/null @@ -1,180 +0,0 @@ -# -*- encoding: utf-8 -*- -import traceback -from io import BytesIO -from pathlib import Path -from typing import List, Union - -import cv2 -import numpy as np -from onnxruntime import GraphOptimizationLevel, InferenceSession, SessionOptions -from PIL import Image, UnidentifiedImageError - -root_dir = Path(__file__).resolve().parent -InputType = Union[str, np.ndarray, bytes, Path, Image.Image] - - -class OrtInferSession: - def __init__(self, model_path: Union[str, Path], num_threads: int = -1): - self.verify_exist(model_path) - - self.num_threads = num_threads - self._init_sess_opt() - - cpu_ep = "CPUExecutionProvider" - cpu_provider_options = { - "arena_extend_strategy": "kSameAsRequested", - } - EP_list = [(cpu_ep, cpu_provider_options)] - try: - self.session = InferenceSession( - str(model_path), sess_options=self.sess_opt, providers=EP_list - ) - except TypeError: - # 这里兼容ort 1.5.2 - self.session = InferenceSession(str(model_path), sess_options=self.sess_opt) - - def _init_sess_opt(self): - self.sess_opt = SessionOptions() - self.sess_opt.log_severity_level = 4 - self.sess_opt.enable_cpu_mem_arena = False - - if self.num_threads != -1: - self.sess_opt.intra_op_num_threads = self.num_threads - - self.sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL - - def __call__(self, input_content: List[np.ndarray]) -> np.ndarray: - input_dict = dict(zip(self.get_input_names(), input_content)) - try: - return self.session.run(None, input_dict) - except Exception as e: - error_info = traceback.format_exc() - raise ONNXRuntimeError(error_info) from e - - def get_input_names( - self, - ): - return [v.name for v in self.session.get_inputs()] - - def get_output_name(self, output_idx=0): - return self.session.get_outputs()[output_idx].name - - def get_metadata(self): - meta_dict = self.session.get_modelmeta().custom_metadata_map - return meta_dict - - @staticmethod - def verify_exist(model_path: Union[Path, str]): - if not isinstance(model_path, Path): - model_path = Path(model_path) - - if not model_path.exists(): - raise FileNotFoundError(f"{model_path} does not exist!") - - if not model_path.is_file(): - raise FileExistsError(f"{model_path} must be a file") - - -class ONNXRuntimeError(Exception): - pass - - -class LoadImage: - def __init__( - self, - ): - pass - - def __call__(self, img: InputType) -> np.ndarray: - if not isinstance(img, InputType.__args__): - raise LoadImageError( - f"The img type {type(img)} does not in {InputType.__args__}" - ) - - origin_img_type = type(img) - img = self.load_img(img) - img = self.convert_img(img, 
origin_img_type) - return img - - def load_img(self, img: InputType) -> np.ndarray: - if isinstance(img, (str, Path)): - self.verify_exist(img) - try: - img = np.array(Image.open(img)) - except UnidentifiedImageError as e: - raise LoadImageError(f"cannot identify image file {img}") from e - return img - - if isinstance(img, bytes): - img = np.array(Image.open(BytesIO(img))) - return img - - if isinstance(img, np.ndarray): - return img - - if isinstance(img, Image.Image): - return np.array(img) - - raise LoadImageError(f"{type(img)} is not supported!") - - def convert_img(self, img: np.ndarray, origin_img_type): - if img.ndim == 2: - return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) - - if img.ndim == 3: - channel = img.shape[2] - if channel == 1: - return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) - - if channel == 2: - return self.cvt_two_to_three(img) - - if channel == 3: - if issubclass(origin_img_type, (str, Path, bytes, Image.Image)): - return cv2.cvtColor(img, cv2.COLOR_RGB2BGR) - return img - - if channel == 4: - return self.cvt_four_to_three(img) - - raise LoadImageError( - f"The channel({channel}) of the img is not in [1, 2, 3, 4]" - ) - - raise LoadImageError(f"The ndim({img.ndim}) of the img is not in [2, 3]") - - @staticmethod - def cvt_two_to_three(img: np.ndarray) -> np.ndarray: - """gray + alpha → BGR""" - img_gray = img[..., 0] - img_bgr = cv2.cvtColor(img_gray, cv2.COLOR_GRAY2BGR) - - img_alpha = img[..., 1] - not_a = cv2.bitwise_not(img_alpha) - not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR) - - new_img = cv2.bitwise_and(img_bgr, img_bgr, mask=img_alpha) - new_img = cv2.add(new_img, not_a) - return new_img - - @staticmethod - def cvt_four_to_three(img: np.ndarray) -> np.ndarray: - """RGBA → BGR""" - r, g, b, a = cv2.split(img) - new_img = cv2.merge((b, g, r)) - - not_a = cv2.bitwise_not(a) - not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR) - - new_img = cv2.bitwise_and(new_img, new_img, mask=a) - new_img = cv2.add(new_img, not_a) - return new_img - - @staticmethod - def verify_exist(file_path: Union[str, Path]): - if not Path(file_path).exists(): - raise LoadImageError(f"{file_path} does not exist.") - - -class LoadImageError(Exception): - pass diff --git a/lineless_table_rec/utils/__init__.py b/lineless_table_rec/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/lineless_table_rec/utils/download_model.py b/lineless_table_rec/utils/download_model.py new file mode 100644 index 0000000..adedb5d --- /dev/null +++ b/lineless_table_rec/utils/download_model.py @@ -0,0 +1,67 @@ +import io +from pathlib import Path +from typing import Optional, Union + +import requests +from tqdm import tqdm + +from .logger import get_logger + +logger = get_logger("DownloadModel") + +PROJECT_DIR = Path(__file__).resolve().parent.parent +DEFAULT_MODEL_DIR = PROJECT_DIR / "models" + + +class DownloadModel: + @classmethod + def download( + cls, + model_full_url: Union[str, Path], + save_dir: Union[str, Path, None] = None, + save_model_name: Optional[str] = None, + ) -> str: + if save_dir is None: + save_dir = DEFAULT_MODEL_DIR + + save_dir.mkdir(parents=True, exist_ok=True) + + if save_model_name is None: + save_model_name = Path(model_full_url).name + + save_file_path = save_dir / save_model_name + if save_file_path.exists(): + logger.debug("%s already exists", save_file_path) + return str(save_file_path) + + try: + logger.info("Download %s to %s", model_full_url, save_dir) + file = cls.download_as_bytes_with_progress(model_full_url, save_model_name) + cls.save_file(save_file_path, 
file) + except Exception as exc: + raise DownloadModelError from exc + return str(save_file_path) + + @staticmethod + def download_as_bytes_with_progress( + url: Union[str, Path], name: Optional[str] = None + ) -> bytes: + resp = requests.get(str(url), stream=True, allow_redirects=True, timeout=180) + total = int(resp.headers.get("content-length", 0)) + bio = io.BytesIO() + with tqdm( + desc=name, total=total, unit="b", unit_scale=True, unit_divisor=1024 + ) as pbar: + for chunk in resp.iter_content(chunk_size=65536): + pbar.update(len(chunk)) + bio.write(chunk) + return bio.getvalue() + + @staticmethod + def save_file(save_path: Union[str, Path], file: bytes): + with open(save_path, "wb") as f: + f.write(file) + + +class DownloadModelError(Exception): + pass diff --git a/lineless_table_rec/utils/logger.py b/lineless_table_rec/utils/logger.py new file mode 100644 index 0000000..2950987 --- /dev/null +++ b/lineless_table_rec/utils/logger.py @@ -0,0 +1,21 @@ +# -*- encoding: utf-8 -*- +# @Author: Jocker1212 +# @Contact: xinyijianggo@gmail.com +import logging +from functools import lru_cache + + +@lru_cache(maxsize=32) +def get_logger(name: str) -> logging.Logger: + logger = logging.getLogger(name) + logger.setLevel(logging.DEBUG) + + fmt = "%(asctime)s - %(name)s - %(levelname)s: %(message)s" + format_str = logging.Formatter(fmt) + + sh = logging.StreamHandler() + sh.setLevel(logging.DEBUG) + + logger.addHandler(sh) + sh.setFormatter(format_str) + return logger diff --git a/lineless_table_rec/utils/utils.py b/lineless_table_rec/utils/utils.py new file mode 100644 index 0000000..9fd55dd --- /dev/null +++ b/lineless_table_rec/utils/utils.py @@ -0,0 +1,477 @@ +# -*- encoding: utf-8 -*- +import os +import platform +import traceback +from enum import Enum +from io import BytesIO +from pathlib import Path +from typing import List, Union, Dict, Any, Tuple, Optional + +import cv2 +import numpy as np +from onnxruntime import ( + GraphOptimizationLevel, + InferenceSession, + SessionOptions, + get_available_providers, + get_device, +) +from PIL import Image, UnidentifiedImageError + +from .logger import get_logger + +root_dir = Path(__file__).resolve().parent +InputType = Union[str, np.ndarray, bytes, Path, Image.Image] + + +class EP(Enum): + CPU_EP = "CPUExecutionProvider" + CUDA_EP = "CUDAExecutionProvider" + DIRECTML_EP = "DmlExecutionProvider" + + +class OrtInferSession: + def __init__(self, config: Dict[str, Any]): + self.logger = get_logger("OrtInferSession") + + model_path = config.get("model_path", None) + self._verify_model(model_path) + + self.cfg_use_cuda = config.get("use_cuda", None) + self.cfg_use_dml = config.get("use_dml", None) + + self.had_providers: List[str] = get_available_providers() + EP_list = self._get_ep_list() + + sess_opt = self._init_sess_opts(config) + self.session = InferenceSession( + model_path, + sess_options=sess_opt, + providers=EP_list, + ) + self._verify_providers() + + @staticmethod + def _init_sess_opts(config: Dict[str, Any]) -> SessionOptions: + sess_opt = SessionOptions() + sess_opt.log_severity_level = 4 + sess_opt.enable_cpu_mem_arena = False + sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL + + cpu_nums = os.cpu_count() + intra_op_num_threads = config.get("intra_op_num_threads", -1) + if intra_op_num_threads != -1 and 1 <= intra_op_num_threads <= cpu_nums: + sess_opt.intra_op_num_threads = intra_op_num_threads + + inter_op_num_threads = config.get("inter_op_num_threads", -1) + if inter_op_num_threads != -1 and 1 <= 
inter_op_num_threads <= cpu_nums:
+            sess_opt.inter_op_num_threads = inter_op_num_threads
+
+        return sess_opt
+
+    def get_metadata(self, key: str = "character") -> list:
+        meta_dict = self.session.get_modelmeta().custom_metadata_map
+        content_list = meta_dict[key].splitlines()
+        return content_list
+
+    def _get_ep_list(self) -> List[Tuple[str, Dict[str, Any]]]:
+        cpu_provider_opts = {
+            "arena_extend_strategy": "kSameAsRequested",
+        }
+        EP_list = [(EP.CPU_EP.value, cpu_provider_opts)]
+
+        cuda_provider_opts = {
+            "device_id": 0,
+            "arena_extend_strategy": "kNextPowerOfTwo",
+            "cudnn_conv_algo_search": "EXHAUSTIVE",
+            "do_copy_in_default_stream": True,
+        }
+        self.use_cuda = self._check_cuda()
+        if self.use_cuda:
+            EP_list.insert(0, (EP.CUDA_EP.value, cuda_provider_opts))
+
+        self.use_directml = self._check_dml()
+        if self.use_directml:
+            self.logger.info(
+                "Windows 10 or above detected, try to use DirectML as primary provider"
+            )
+            directml_options = (
+                cuda_provider_opts if self.use_cuda else cpu_provider_opts
+            )
+            EP_list.insert(0, (EP.DIRECTML_EP.value, directml_options))
+        return EP_list
+
+    def _check_cuda(self) -> bool:
+        if not self.cfg_use_cuda:
+            return False
+
+        cur_device = get_device()
+        if cur_device == "GPU" and EP.CUDA_EP.value in self.had_providers:
+            return True
+
+        self.logger.warning(
+            "%s is not in available providers (%s). Use %s inference by default.",
+            EP.CUDA_EP.value,
+            self.had_providers,
+            self.had_providers[0],
+        )
+        self.logger.info("Recommend using rapidocr_paddle for inference on GPU.")
+        self.logger.info(
+            "(For reference only) If you want to use GPU acceleration, you must do:"
+        )
+        self.logger.info(
+            "First, uninstall all onnxruntime packages in the current environment."
+        )
+        self.logger.info(
+            "Second, install onnxruntime-gpu by `pip install onnxruntime-gpu`."
+        )
+        self.logger.info(
+            "\tNote the onnxruntime-gpu version must match your cuda and cudnn version."
+        )
+        self.logger.info(
+            "\tYou can refer to this link: https://onnxruntime.ai/docs/execution-providers/CUDA-EP.html"
+        )
+        self.logger.info(
+            "Third, ensure %s is in the available providers list. e.g. ['CUDAExecutionProvider', 'CPUExecutionProvider']",
+            EP.CUDA_EP.value,
+        )
+        return False
+
+    def _check_dml(self) -> bool:
+        if not self.cfg_use_dml:
+            return False
+
+        cur_os = platform.system()
+        if cur_os != "Windows":
+            self.logger.warning(
+                "DirectML is only supported on Windows. The current OS is %s. Use %s inference by default.",
+                cur_os,
+                self.had_providers[0],
+            )
+            return False
+
+        cur_window_version = int(platform.release().split(".")[0])
+        if cur_window_version < 10:
+            self.logger.warning(
+                "DirectML is only supported on Windows 10 and above. The current Windows version is %s. Use %s inference by default.",
+                cur_window_version,
+                self.had_providers[0],
+            )
+            return False
+
+        if EP.DIRECTML_EP.value in self.had_providers:
+            return True
+
+        self.logger.warning(
+            "%s is not in available providers (%s). Use %s inference by default.",
+            EP.DIRECTML_EP.value,
+            self.had_providers,
+            self.had_providers[0],
+        )
+        self.logger.info("If you want to use DirectML acceleration, you must do:")
+        self.logger.info(
+            "First, uninstall all onnxruntime packages in the current environment."
+        )
+        self.logger.info(
+            "Second, install onnxruntime-directml by `pip install onnxruntime-directml`"
+        )
+        self.logger.info(
+            "Third, ensure %s is in the available providers list. e.g. ['DmlExecutionProvider', 'CPUExecutionProvider']",
+            EP.DIRECTML_EP.value,
+        )
+        return False
+
+    def _verify_providers(self):
+        session_providers = self.session.get_providers()
+        first_provider = session_providers[0]
+
+        if self.use_cuda and first_provider != EP.CUDA_EP.value:
+            self.logger.warning(
+                "%s is not available for the current env; inference is automatically shifted to be executed under %s.",
+                EP.CUDA_EP.value,
+                first_provider,
+            )
+
+        if self.use_directml and first_provider != EP.DIRECTML_EP.value:
+            self.logger.warning(
+                "%s is not available for the current env; inference is automatically shifted to be executed under %s.",
+                EP.DIRECTML_EP.value,
+                first_provider,
+            )
+
+    def __call__(self, input_content: List[np.ndarray]) -> np.ndarray:
+        input_dict = dict(zip(self.get_input_names(), input_content))
+        try:
+            return self.session.run(None, input_dict)
+        except Exception as e:
+            error_info = traceback.format_exc()
+            raise ONNXRuntimeError(error_info) from e
+
+    def get_input_names(self) -> List[str]:
+        return [v.name for v in self.session.get_inputs()]
+
+    def get_output_names(self) -> List[str]:
+        return [v.name for v in self.session.get_outputs()]
+
+    def get_character_list(self, key: str = "character") -> List[str]:
+        meta_dict = self.session.get_modelmeta().custom_metadata_map
+        return meta_dict[key].splitlines()
+
+    def have_key(self, key: str = "character") -> bool:
+        meta_dict = self.session.get_modelmeta().custom_metadata_map
+        if key in meta_dict.keys():
+            return True
+        return False
+
+    @staticmethod
+    def _verify_model(model_path: Union[str, Path, None]):
+        if model_path is None:
+            raise ValueError("model_path is None!")
+
+        model_path = Path(model_path)
+        if not model_path.exists():
+            raise FileNotFoundError(f"{model_path} does not exist.")
+
+        if not model_path.is_file():
+            raise FileExistsError(f"{model_path} is not a file.")
+
+
+class ONNXRuntimeError(Exception):
+    pass
+
+
+class LoadImage:
+    def __init__(
+        self,
+    ):
+        pass
+
+    def __call__(self, img: InputType) -> np.ndarray:
+        if not isinstance(img, InputType.__args__):
+            raise LoadImageError(
+                f"The img type {type(img)} is not in {InputType.__args__}"
+            )
+
+        origin_img_type = type(img)
+        img = self.load_img(img)
+        img = self.convert_img(img, origin_img_type)
+        return img
+
+    def load_img(self, img: InputType) -> np.ndarray:
+        if isinstance(img, (str, Path)):
+            self.verify_exist(img)
+            try:
+                img = np.array(Image.open(img))
+            except UnidentifiedImageError as e:
+                raise LoadImageError(f"cannot identify image file {img}") from e
+            return img
+
+        if isinstance(img, bytes):
+            img = np.array(Image.open(BytesIO(img)))
+            return img
+
+        if isinstance(img, np.ndarray):
+            return img
+
+        if isinstance(img, Image.Image):
+            return np.array(img)
+
+        raise LoadImageError(f"{type(img)} is not supported!")
+
+    def convert_img(self, img: np.ndarray, origin_img_type):
+        if img.ndim == 2:
+            return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+
+        if img.ndim == 3:
+            channel = img.shape[2]
+            if channel == 1:
+                return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+
+            if channel == 2:
+                return self.cvt_two_to_three(img)
+
+            if channel == 3:
+                if issubclass(origin_img_type, (str, Path, bytes, Image.Image)):
+                    return cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
+                return img
+
+            if channel == 4:
+                return self.cvt_four_to_three(img)
+
+            raise LoadImageError(
+                f"The channel({channel}) of the img is not in [1, 2, 3, 4]"
+            )
+
+        raise LoadImageError(f"The ndim({img.ndim}) of the img is not in [2, 3]")
+
+    @staticmethod
+
+
+class LoadImage:
+    def __init__(
+        self,
+    ):
+        pass
+
+    def __call__(self, img: InputType) -> np.ndarray:
+        if not isinstance(img, InputType.__args__):
+            raise LoadImageError(
+                f"The img type {type(img)} is not in {InputType.__args__}"
+            )
+
+        origin_img_type = type(img)
+        img = self.load_img(img)
+        img = self.convert_img(img, origin_img_type)
+        return img
+
+    def load_img(self, img: InputType) -> np.ndarray:
+        if isinstance(img, (str, Path)):
+            self.verify_exist(img)
+            try:
+                img = np.array(Image.open(img))
+            except UnidentifiedImageError as e:
+                raise LoadImageError(f"cannot identify image file {img}") from e
+            return img
+
+        if isinstance(img, bytes):
+            img = np.array(Image.open(BytesIO(img)))
+            return img
+
+        if isinstance(img, np.ndarray):
+            return img
+
+        if isinstance(img, Image.Image):
+            return np.array(img)
+
+        raise LoadImageError(f"{type(img)} is not supported!")
+
+    def convert_img(self, img: np.ndarray, origin_img_type):
+        if img.ndim == 2:
+            return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+
+        if img.ndim == 3:
+            channel = img.shape[2]
+            if channel == 1:
+                return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+
+            if channel == 2:
+                return self.cvt_two_to_three(img)
+
+            if channel == 3:
+                if issubclass(origin_img_type, (str, Path, bytes, Image.Image)):
+                    return cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
+                return img
+
+            if channel == 4:
+                return self.cvt_four_to_three(img)
+
+            raise LoadImageError(
+                f"The channel({channel}) of the img is not in [1, 2, 3, 4]"
+            )
+
+        raise LoadImageError(f"The ndim({img.ndim}) of the img is not in [2, 3]")
+
+    @staticmethod
+    def cvt_two_to_three(img: np.ndarray) -> np.ndarray:
+        """gray + alpha → BGR"""
+        img_gray = img[..., 0]
+        img_bgr = cv2.cvtColor(img_gray, cv2.COLOR_GRAY2BGR)
+
+        img_alpha = img[..., 1]
+        not_a = cv2.bitwise_not(img_alpha)
+        not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR)
+
+        new_img = cv2.bitwise_and(img_bgr, img_bgr, mask=img_alpha)
+        new_img = cv2.add(new_img, not_a)
+        return new_img
+
+    @staticmethod
+    def cvt_four_to_three(img: np.ndarray) -> np.ndarray:
+        """RGBA → BGR"""
+        r, g, b, a = cv2.split(img)
+        new_img = cv2.merge((b, g, r))
+
+        not_a = cv2.bitwise_not(a)
+        not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR)
+
+        new_img = cv2.bitwise_and(new_img, new_img, mask=a)
+        new_img = cv2.add(new_img, not_a)
+        return new_img
+
+    @staticmethod
+    def verify_exist(file_path: Union[str, Path]):
+        if not Path(file_path).exists():
+            raise LoadImageError(f"{file_path} does not exist.")
+
+
+class LoadImageError(Exception):
+    pass
+
+
+class VisTable:
+    def __init__(self):
+        self.load_img = LoadImage()
+
+    def __call__(
+        self,
+        img_path: Union[str, Path],
+        table_results,
+        save_html_path: Optional[Union[str, Path]] = None,
+        save_drawed_path: Optional[Union[str, Path]] = None,
+        save_logic_path: Optional[Union[str, Path]] = None,
+    ):
+        if save_html_path:
+            html_with_border = self.insert_border_style(table_results.pred_html)
+            self.save_html(save_html_path, html_with_border)
+
+        table_cell_bboxes = table_results.cell_bboxes
+        table_logic_points = table_results.logic_points
+        if table_cell_bboxes is None:
+            return None
+
+        img = self.load_img(img_path)
+
+        dims_bboxes = table_cell_bboxes.shape[1]
+        if dims_bboxes == 4:
+            drawed_img = self.draw_rectangle(img, table_cell_bboxes)
+        elif dims_bboxes == 8:
+            drawed_img = self.draw_polylines(img, table_cell_bboxes)
+        else:
+            raise ValueError("The shape of table bounding boxes must be 4 or 8.")
+
+        if save_drawed_path:
+            self.save_img(save_drawed_path, drawed_img)
+
+        if save_logic_path:
+            polygons = [[box[0], box[1], box[4], box[5]] for box in table_cell_bboxes]
+            self.plot_rec_box_with_logic_info(
+                img_path, save_logic_path, table_logic_points, polygons
+            )
+        return drawed_img
+
+    def insert_border_style(self, table_html_str: str):
+        style_res = """<meta charset="UTF-8"><style>
+        table {
+            border-collapse: collapse;
+            width: 100%;
+        }
+        th, td {
+            border: 1px solid black;
+            padding: 8px;
+            text-align: center;
+        }
+        td {
+            min-width: 50px;
+        }
+        </style>"""
+
+        prefix_table, suffix_table = table_html_str.split("<body>")
+        html_with_border = f"{prefix_table}<body>{style_res}{suffix_table}"
+        return html_with_border
+
+    def plot_rec_box_with_logic_info(
+        self, img_path, output_path, logic_points, sorted_polygons
+    ):
+        """
+        :param img_path
+        :param output_path
+        :param logic_points: [row_start,row_end,col_start,col_end]
+        :param sorted_polygons: [xmin,ymin,xmax,ymax]
+        :return:
+        """
+        # read the original image
+        img = cv2.imread(img_path)
+        img = cv2.copyMakeBorder(
+            img, 0, 0, 0, 100, cv2.BORDER_CONSTANT, value=[255, 255, 255]
+        )
+        # draw the polygons as rectangles
+        for idx, polygon in enumerate(sorted_polygons):
+            x0, y0, x1, y1 = polygon[0], polygon[1], polygon[2], polygon[3]
+            x0 = round(x0)
+            y0 = round(y0)
+            x1 = round(x1)
+            y1 = round(y1)
+            cv2.rectangle(img, (x0, y0), (x1, y1), (0, 0, 255), 1)
+            # enlarge the font size and line width
+            font_scale = 0.9  # previously 0.5
+            thickness = 1  # previously 1
+            logic_point = logic_points[idx]
+            cv2.putText(
+                img,
+                f"row: {logic_point[0]}-{logic_point[1]}",
+                (x0 + 3, y0 + 8),
+                cv2.FONT_HERSHEY_PLAIN,
+                font_scale,
+                (0, 0, 255),
+                thickness,
+            )
+            cv2.putText(
+                img,
+                f"col: {logic_point[2]}-{logic_point[3]}",
+                (x0 + 3, y0 + 18),
+                cv2.FONT_HERSHEY_PLAIN,
+                font_scale,
+                (0, 0, 255),
+                thickness,
+            )
+        os.makedirs(os.path.dirname(output_path), exist_ok=True)
+        # save the drawn image
+        self.save_img(output_path, img)
+
+    @staticmethod
+    def draw_rectangle(img: np.ndarray, boxes: np.ndarray) -> np.ndarray:
+        img_copy = img.copy()
+        for box in boxes.astype(int):
+            x1, y1, x2, y2 = box
+            cv2.rectangle(img_copy, (x1, y1), (x2, y2), (255, 0, 0), 2)
+        return img_copy
+
+    @staticmethod
+    def draw_polylines(img: np.ndarray, points) -> np.ndarray:
+        img_copy = img.copy()
+        for point in points.astype(int):
+            point = point.reshape(4, 2)
+            cv2.polylines(img_copy, [point.astype(int)], True, (255, 0, 0), 2)
+        return img_copy
+
+    @staticmethod
+    def save_img(save_path: Union[str, Path], img: np.ndarray):
+        cv2.imwrite(str(save_path), img)
+
+    @staticmethod
+    def save_html(save_path: Union[str, Path], html: str):
+        with open(save_path, "w", encoding="utf-8") as f:
+            f.write(html)
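+
+
+# Usage sketch (illustrative; the paths are placeholders and `table_results`
+# is assumed to be the RapidTableOutput returned by the table engine):
+#   viser = VisTable()
+#   viser(img_path, table_results, save_html_path, save_drawed_path, save_logic_path)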
diff --git a/lineless_table_rec/utils/utils_table_lore_rec.py b/lineless_table_rec/utils/utils_table_lore_rec.py
new file mode 100644
index 0000000..7043d3c
--- /dev/null
+++ b/lineless_table_rec/utils/utils_table_lore_rec.py
@@ -0,0 +1,408 @@
+# ------------------------------------------------------------------------------
+# Part of implementation is adopted from CenterNet,
+# made publicly available under the MIT License at https://github.com/xingyizhou/CenterNet.git
+# ------------------------------------------------------------------------------
+import warnings
+from typing import Dict, List, Tuple, Union
+
+import cv2
+import numpy as np
+
+# suppress warnings
+warnings.filterwarnings("ignore")
+
+
+class DetProcess:
+    def __init__(self, K: int = 3000, num_classes: int = 2, scale: float = 1.0):
+        self.K = K
+        self.num_classes = num_classes
+        self.scale = scale
+        self.max_per_image = 3000
+
+    def __call__(
+        self, det_out: Dict[str, np.ndarray], meta: Dict[str, Union[int, np.ndarray]]
+    ):
+        hm = self.sigmoid(det_out["hm"])
+        dets, keep, logi, cr = ctdet_4ps_decode(
+            hm[:, 0:1, :, :],
+            det_out["wh"],
+            det_out["ax"],
+            det_out["cr"],
+            reg=det_out["reg"],
+            K=self.K,
+        )
+
+        raw_dets = dets
+        dets = dets.reshape(1, -1, dets.shape[2])
+        dets = ctdet_4ps_post_process_upper_left(
+            dets.copy(),
+            [meta["c"]],
+            [meta["s"]],
+            meta["out_height"],
+            meta["out_width"],
+            2,
+        )
+        for j in range(1, self.num_classes + 1):
+            dets[0][j] = np.array(dets[0][j], dtype=np.float32).reshape(-1, 9)
+            dets[0][j][:, :8] /= self.scale
+        dets = dets[0]
+        detections = [dets]
+
+        logi += cr
+        results = self.merge_outputs(detections)
+        slct_logi_feat, slct_dets_feat = self.filter(results, logi, raw_dets[:, :, :8])
+        slct_output_dets = results[1][: slct_logi_feat.shape[1], :8]
+        return slct_logi_feat, slct_dets_feat, slct_output_dets
+
+    @staticmethod
+    def sigmoid(data: np.ndarray) -> np.ndarray:
+        return 1 / (1 + np.exp(-data))
+
+    def merge_outputs(self, detections: List[Dict[int, np.ndarray]]) -> Dict[int, np.ndarray]:
+        # thresh_conf, thresh_min, thresh_max = 0.1, 0.5, 0.7
+        results = {}
+        for j in range(1, self.num_classes + 1):
+            results[j] = np.concatenate(
+                [detection[j] for detection in detections], axis=0
+            ).astype(np.float32)
+
+        scores = np.hstack([results[j][:, 8] for j in range(1, self.num_classes + 1)])
+        if len(scores) > self.max_per_image:
+            kth = len(scores) - self.max_per_image
+            thresh = np.partition(scores, kth)[kth]
+            for j in range(1, self.num_classes + 1):
+                keep_inds = results[j][:, 8] >= thresh
+                results[j] = results[j][keep_inds]
+        return results
+
+    @staticmethod
+    def filter(
+        results: Dict[int, np.ndarray], logi: np.ndarray, ps: np.ndarray
+    ) -> Tuple[np.ndarray, np.ndarray]:
+        # the following lines 
select boxes + batch_size, feat_dim = logi.shape[0], logi.shape[2] + num_valid = sum(results[1][:, 8] >= 0.15) + + slct_logi = np.zeros((batch_size, num_valid, feat_dim), dtype=np.float32) + slct_dets = np.zeros((batch_size, num_valid, 8), dtype=np.int32) + for i in range(batch_size): + for j in range(num_valid): + slct_logi[i, j, :] = logi[i, j, :] + slct_dets[i, j, :] = ps[i, j, :] + + return slct_logi, slct_dets + + +def ctdet_4ps_decode( + heat: np.ndarray, + wh: np.ndarray, + ax: np.ndarray, + cr: np.ndarray, + reg: np.ndarray = None, + cat_spec_wh: bool = False, + K: int = 100, +): + batch, cat, _, width = heat.shape + heat, keep = _nms(heat) + scores, inds, clses, ys, xs = _topk(heat, K=K) + + if reg is not None: + reg = _tranpose_and_gather_feat(reg, inds) + reg = reg.reshape(batch, K, 2) + xs = xs.reshape(batch, K, 1) + reg[:, :, 0:1] + ys = ys.reshape(batch, K, 1) + reg[:, :, 1:2] + else: + xs = xs.reshape(batch, K, 1) + 0.5 + ys = ys.reshape(batch, K, 1) + 0.5 + + wh = _tranpose_and_gather_feat(wh, inds) + ax = _tranpose_and_gather_feat(ax, inds) + + if cat_spec_wh: + wh = wh.reshape(batch, K, cat, 8) + clses_ind = clses.reshape(batch, K, 1, 1).expand(batch, K, 1, 8) + wh = wh.gather(2, clses_ind).reshape(batch, K, 8) + else: + wh = wh.reshape(batch, K, 8) + + clses = clses.reshape(batch, K, 1) + scores = scores.reshape(batch, K, 1) + + bboxes_vec = [ + xs - wh[..., 0:1], + ys - wh[..., 1:2], + xs - wh[..., 2:3], + ys - wh[..., 3:4], + xs - wh[..., 4:5], + ys - wh[..., 5:6], + xs - wh[..., 6:7], + ys - wh[..., 7:8], + ] + bboxes = np.concatenate(bboxes_vec, axis=2) + + cc_match = np.concatenate( + [ + (xs - wh[..., 0:1]) + width * np.round(ys - wh[..., 1:2]), + (xs - wh[..., 2:3]) + width * np.round(ys - wh[..., 3:4]), + (xs - wh[..., 4:5]) + width * np.round(ys - wh[..., 5:6]), + (xs - wh[..., 6:7]) + width * np.round(ys - wh[..., 7:8]), + ], + axis=2, + ) + cc_match = np.round(cc_match).astype(np.int64) + cr_feat = _get_4ps_feat(cc_match, cr) + cr_feat = cr_feat.sum(axis=3) + + detections = np.concatenate([bboxes, scores, clses], axis=2) + return detections, keep, ax, cr_feat + + +def _nms(heat: np.ndarray, kernel: int = 3) -> Tuple[np.ndarray, np.ndarray]: + pad = (kernel - 1) // 2 + hmax = max_pool(heat, kernel_size=kernel, stride=1, padding=pad) + keep = hmax == heat + return heat * keep, keep + + +def max_pool( + img: np.ndarray, kernel_size: int, stride: int, padding: int +) -> np.ndarray: + h, w = img.shape[2:] + img = np.pad( + img, + ((0, 0), (0, 0), (padding, padding), (padding, padding)), + "constant", + constant_values=0, + ) + + res_h = ((h + 2 - kernel_size) // stride) + 1 + res_w = ((w + 2 - kernel_size) // stride) + 1 + res = np.zeros((img.shape[0], img.shape[1], res_h, res_w)) + for i in range(res_h): + for j in range(res_w): + temp = img[ + :, + :, + i * stride : i * stride + kernel_size, + j * stride : j * stride + kernel_size, + ] + res[:, :, i, j] = temp.max() + return res + + +def _topk( + scores: np.ndarray, K: int = 40 +) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + batch, cat, height, width = scores.shape + + topk_scores, topk_inds = find_topk(scores.reshape(batch, cat, -1), K) + + topk_inds = topk_inds % (height * width) + topk_ys = topk_inds / width + topk_xs = np.float32(np.int32(topk_inds % width)) + + topk_score, topk_ind = find_topk(topk_scores.reshape(batch, -1), K) + topk_clses = np.int32(topk_ind / K) + topk_inds = _gather_feat(topk_inds.reshape(batch, -1, 1), topk_ind).reshape( + batch, K + ) + topk_ys = 
_gather_feat(topk_ys.reshape(batch, -1, 1), topk_ind).reshape(batch, K) + topk_xs = _gather_feat(topk_xs.reshape(batch, -1, 1), topk_ind).reshape(batch, K) + + return topk_score, topk_inds, topk_clses, topk_ys, topk_xs + + +def find_topk( + a: np.ndarray, k: int, axis: int = -1, largest: bool = True, sorted: bool = True +) -> Tuple[np.ndarray, np.ndarray]: + if axis is None: + axis_size = a.size + else: + axis_size = a.shape[axis] + assert 1 <= k <= axis_size + + a = np.asanyarray(a) + if largest: + index_array = np.argpartition(a, axis_size - k, axis=axis) + topk_indices = np.take(index_array, -np.arange(k) - 1, axis=axis) + else: + index_array = np.argpartition(a, k - 1, axis=axis) + topk_indices = np.take(index_array, np.arange(k), axis=axis) + + topk_values = np.take_along_axis(a, topk_indices, axis=axis) + if sorted: + sorted_indices_in_topk = np.argsort(topk_values, axis=axis) + if largest: + sorted_indices_in_topk = np.flip(sorted_indices_in_topk, axis=axis) + + sorted_topk_values = np.take_along_axis( + topk_values, sorted_indices_in_topk, axis=axis + ) + sorted_topk_indices = np.take_along_axis( + topk_indices, sorted_indices_in_topk, axis=axis + ) + return sorted_topk_values, sorted_topk_indices + return topk_values, topk_indices + + +def _gather_feat(feat: np.ndarray, ind: np.ndarray) -> np.ndarray: + dim = feat.shape[2] + ind = np.broadcast_to(ind[:, :, None], (ind.shape[0], ind.shape[1], dim)) + feat = _gather(feat, 1, ind) + return feat + + +def _gather(data: np.ndarray, dim: int, index: np.ndarray) -> np.ndarray: + """ + Gathers values along an axis specified by dim. + For a 3-D tensor the output is specified by: + out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0 + out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1 + out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2 + + :param dim: The axis along which to index + :param index: A tensor of indices of elements to gather + :return: tensor of gathered values + """ + idx_xsection_shape = index.shape[:dim] + index.shape[dim + 1 :] + data_xsection_shape = data.shape[:dim] + data.shape[dim + 1 :] + if idx_xsection_shape != data_xsection_shape: + raise ValueError( + "Except for dimension " + + str(dim) + + ", all dimensions of index and data should be the same size" + ) + + if index.dtype != np.int64: + raise TypeError("The values of index must be integers") + + data_swaped = np.swapaxes(data, 0, dim) + index_swaped = np.swapaxes(index, 0, dim) + gathered = np.take_along_axis(data_swaped, index_swaped, axis=0) + return np.swapaxes(gathered, 0, dim) + + +def _tranpose_and_gather_feat(feat: np.ndarray, ind: np.ndarray) -> np.ndarray: + feat = np.ascontiguousarray(np.transpose(feat, [0, 2, 3, 1])) + feat = feat.reshape(feat.shape[0], -1, feat.shape[3]) + feat = _gather_feat(feat, ind) + return feat + + +def _get_4ps_feat(cc_match: np.ndarray, output: np.ndarray) -> np.ndarray: + if isinstance(output, dict): + feat = output["cr"] + else: + feat = output + + feat = np.ascontiguousarray(feat.transpose(0, 2, 3, 1)) + feat = feat.reshape(feat.shape[0], -1, feat.shape[3]) + feat = feat[..., None] + feat = np.concatenate([feat] * 4, axis=-1) + + dim = feat.shape[2] + cc_match = cc_match[..., None, :] + cc_match = np.concatenate([cc_match] * dim, axis=2) + if not (isinstance(output, dict)): + cc_match = np.where( + cc_match < feat.shape[1], + cc_match, + (feat.shape[0] - 1) * np.ones(cc_match.shape).astype(np.int64), + ) + + cc_match = np.where( + cc_match >= 0, cc_match, np.zeros(cc_match.shape).astype(np.int64) + ) + feat 
= np.take_along_axis(feat, cc_match, axis=1) + return feat + + +def ctdet_4ps_post_process_upper_left( + dets: np.ndarray, + c: List[np.ndarray], + s: List[float], + h: int, + w: int, + num_classes: int, +) -> np.ndarray: + # dets: batch x max_dets x dim + # return 1-based class det dict + ret = [] + for i in range(dets.shape[0]): + top_preds = {} + dets[i, :, 0:2] = transform_preds_upper_left( + dets[i, :, 0:2], c[i], s[i], (w, h) + ) + dets[i, :, 2:4] = transform_preds_upper_left( + dets[i, :, 2:4], c[i], s[i], (w, h) + ) + dets[i, :, 4:6] = transform_preds_upper_left( + dets[i, :, 4:6], c[i], s[i], (w, h) + ) + dets[i, :, 6:8] = transform_preds_upper_left( + dets[i, :, 6:8], c[i], s[i], (w, h) + ) + classes = dets[i, :, -1] + for j in range(num_classes): + inds = classes == j + tmp_top_pred = [ + dets[i, inds, :8].astype(np.float32), + dets[i, inds, 8:9].astype(np.float32), + ] + top_preds[j + 1] = np.concatenate(tmp_top_pred, axis=1).tolist() + ret.append(top_preds) + return ret + + +def transform_preds_upper_left( + coords: np.ndarray, + center: np.ndarray, + scale: float, + output_size: Tuple[int, int], +) -> np.ndarray: + target_coords = np.zeros(coords.shape) + + trans = get_affine_transform_upper_left(center, scale, output_size, inv=1) + for p in range(coords.shape[0]): + target_coords[p, 0:2] = affine_transform(coords[p, 0:2], trans) + return target_coords + + +def get_affine_transform_upper_left( + center: np.ndarray, + scale: float, + output_size: List[Tuple[int, int]], + inv: int = 0, +) -> np.ndarray: + if not isinstance(scale, np.ndarray) and not isinstance(scale, list): + scale = np.array([scale, scale], dtype=np.float32) + + src = np.zeros((3, 2), dtype=np.float32) + dst = np.zeros((3, 2), dtype=np.float32) + src[0, :] = center + dst[0, :] = [0, 0] + if center[0] < center[1]: + src[1, :] = [scale[0], center[1]] + dst[1, :] = [output_size[0], 0] + else: + src[1, :] = [center[0], scale[0]] + dst[1, :] = [0, output_size[0]] + src[2:, :] = get_3rd_point(src[0, :], src[1, :]) + dst[2:, :] = get_3rd_point(dst[0, :], dst[1, :]) + + if inv: + trans = cv2.getAffineTransform(np.float32(dst), np.float32(src)) + else: + trans = cv2.getAffineTransform(np.float32(src), np.float32(dst)) + return trans + + +def get_3rd_point(a: np.ndarray, b: np.ndarray) -> np.ndarray: + direct = a - b + return b + np.array([-direct[1], direct[0]], dtype=np.float32) + + +def affine_transform(pt: np.ndarray, t: np.ndarray) -> np.ndarray: + new_pt = np.array([pt[0], pt[1], 1.0], dtype=np.float32).T + new_pt = np.dot(t, new_pt) + return new_pt[:2] diff --git a/lineless_table_rec/utils_table_recover.py b/lineless_table_rec/utils/utils_table_recover.py similarity index 100% rename from lineless_table_rec/utils_table_recover.py rename to lineless_table_rec/utils/utils_table_recover.py diff --git a/tests/test_lineless_table_rec.py b/tests/test_lineless_table_rec.py index 1cf2c20..0cd36f1 100644 --- a/tests/test_lineless_table_rec.py +++ b/tests/test_lineless_table_rec.py @@ -10,7 +10,7 @@ sys.path.append(str(root_dir)) -from lineless_table_rec.utils_table_recover import * +from lineless_table_rec.utils.utils_table_recover import * from lineless_table_rec import LinelessTableRecognition test_file_dir = cur_dir / "test_files" diff --git a/tests/test_wired_table_line_util.py b/tests/test_wired_table_line_util.py index 9d22bc3..35fc2f3 100644 --- a/tests/test_wired_table_line_util.py +++ b/tests/test_wired_table_line_util.py @@ -1,6 +1,6 @@ import pytest import numpy as np -from wired_table_rec.utils_table_line_rec 
import ( +from wired_table_rec.utils.utils_table_line_rec import ( _order_points, calculate_center_rotate_angle, fit_line, diff --git a/tests/test_wired_table_rec.py b/tests/test_wired_table_rec.py index 01f8e76..6604a57 100644 --- a/tests/test_wired_table_rec.py +++ b/tests/test_wired_table_rec.py @@ -9,7 +9,7 @@ from rapidocr_onnxruntime import RapidOCR from wired_table_rec.utils import rescale_size -from wired_table_rec.utils_table_recover import ( +from wired_table_rec.utils.utils_table_recover import ( plot_html_table, is_single_axis_contained, gather_ocr_list_by_row, diff --git a/wired_table_rec/__init__.py b/wired_table_rec/__init__.py index ca11d54..f43cfe4 100644 --- a/wired_table_rec/__init__.py +++ b/wired_table_rec/__init__.py @@ -2,6 +2,6 @@ # @Author: SWHL # @Contact: liekkaskono@163.com from .main import WiredTableRecognition -from .utils_table_recover import vis_table +from wired_table_rec.utils.utils_table_recover import vis_table __all__ = ["WiredTableRecognition", "vis_table"] diff --git a/wired_table_rec/main.py b/wired_table_rec/main.py index 51bfd03..b470525 100644 --- a/wired_table_rec/main.py +++ b/wired_table_rec/main.py @@ -6,16 +6,19 @@ import logging import time import traceback +from dataclasses import dataclass, asdict +from enum import Enum from pathlib import Path -from typing import List, Optional, Tuple, Union, Dict, Any +from typing import List, Optional, Union, Dict, Any import numpy as np import cv2 -from wired_table_rec.table_line_rec import TableLineRecognition -from wired_table_rec.table_line_rec_plus import TableLineRecognitionPlus +from wired_table_rec.table_structure_cycle_center_net import TSRCycleCenterNet +from wired_table_rec.table_structure_unet import TSRUnet +from wired_table_rec.utils.download_model import DownloadModel from .table_recover import TableRecover -from .utils import InputType, LoadImage -from .utils_table_recover import ( +from .utils.utils import InputType, LoadImage +from wired_table_rec.utils.utils_table_recover import ( match_ocr_cell, plot_html_table, box_4_2_poly_to_box_4_1, @@ -24,20 +27,51 @@ gather_ocr_list_by_row, ) -cur_dir = Path(__file__).resolve().parent -default_model_path = cur_dir / "models" / "cycle_center_net_v1.onnx" -default_model_path_v2 = cur_dir / "models" / "cycle_center_net_v2.onnx" + +class ModelType(Enum): + CYCLE_CENTER_NET = "cycle_center_net" + UNET = "unet" + + +ROOT_URL = "https://www.modelscope.cn/models/RapidAI/RapidTable/resolve/master/" +KEY_TO_MODEL_URL = { + ModelType.CYCLE_CENTER_NET.value: f"{ROOT_URL}/cycle_center_net.onnx", + ModelType.UNET.value: f"{ROOT_URL}/unet.onnx", +} + + +@dataclass +class RapidTableInput: + model_type: Optional[str] = ModelType.UNET.value + model_path: Union[str, Path, None, Dict[str, str]] = None + use_cuda: bool = False + device: str = "cpu" + + +@dataclass +class RapidTableOutput: + pred_html: Optional[str] = None + cell_bboxes: Optional[np.ndarray] = None + logic_points: Optional[np.ndarray] = None + elapse: Optional[float] = None class WiredTableRecognition: - def __init__(self, table_model_path: Union[str, Path] = None, version="v2"): - self.load_img = LoadImage() - if version == "v2": - model_path = table_model_path if table_model_path else default_model_path_v2 - self.table_line_rec = TableLineRecognitionPlus(str(model_path)) + def __init__(self, config: RapidTableInput): + self.model_type = config.model_type + if self.model_type not in KEY_TO_MODEL_URL: + model_list = ",".join(KEY_TO_MODEL_URL) + raise ValueError( + f"{self.model_type} is not 
supported. The currently supported models are {model_list}." + ) + + config.model_path = self.get_model_path(config.model_type, config.model_path) + if self.model_type == ModelType.CYCLE_CENTER_NET.value: + self.table_structure = TSRCycleCenterNet(asdict(config)) else: - model_path = table_model_path if table_model_path else default_model_path - self.table_line_rec = TableLineRecognition(str(model_path)) + self.table_structure = TSRUnet(asdict(config)) + + self.load_img = LoadImage() self.table_recover = TableRecover() @@ -51,12 +85,7 @@ def __call__( img: InputType, ocr_result: Optional[List[Union[List[List[float]], str, str]]] = None, **kwargs, - ) -> Tuple[str, float, Any, Any, Any]: - if self.ocr is None and ocr_result is None: - raise ValueError( - "One of two conditions must be met: ocr_result is not empty, or rapidocr_onnxruntime is installed." - ) - + ) -> RapidTableOutput: s = time.perf_counter() rec_again = True need_ocr = True @@ -68,10 +97,10 @@ def __call__( col_threshold = kwargs.get("col_threshold", 15) row_threshold = kwargs.get("row_threshold", 10) img = self.load_img(img) - polygons, rotated_polygons = self.table_line_rec(img, **kwargs) + polygons, rotated_polygons = self.table_structure(img, **kwargs) if polygons is None: logging.warning("polygons is None.") - return "", 0.0, None, None, None + return RapidTableOutput("", None, None, 0.0) try: table_res, logi_points = self.table_recover( @@ -86,12 +115,11 @@ def __call__( sorted_polygons, idx_list = sorted_ocr_boxes( [box_4_2_poly_to_box_4_1(box) for box in polygons] ) - return ( + return RapidTableOutput( "", - time.perf_counter() - s, sorted_polygons, logi_points[idx_list], - [], + time.perf_counter() - s, ) if ocr_result is None and need_ocr: ocr_result, _ = self.ocr(img) @@ -108,25 +136,15 @@ def __call__( i: [ocr_box_and_text[1] for ocr_box_and_text in t_box_ocr["t_ocr_res"]] for i, t_box_ocr in enumerate(t_rec_ocr_list) } - table_str = plot_html_table(logi_points, cell_box_det_map) - ocr_boxes_res = [ - box_4_2_poly_to_box_4_1(ori_ocr[0]) for ori_ocr in ocr_result - ] - sorted_ocr_boxes_res, _ = sorted_ocr_boxes(ocr_boxes_res) - sorted_polygons = [box_4_2_poly_to_box_4_1(box) for box in polygons] - sorted_logi_points = logi_points - table_elapse = time.perf_counter() - s + pred_html = plot_html_table(logi_points, cell_box_det_map) + polygons = polygons.reshape(-1, 8) + logi_points = np.array(logi_points) + elapse = time.perf_counter() - s except Exception: logging.warning(traceback.format_exc()) - return "", 0.0, None, None, None - return ( - table_str, - table_elapse, - sorted_polygons, - sorted_logi_points, - sorted_ocr_boxes_res, - ) + return RapidTableOutput("", None, None, 0.0) + return RapidTableOutput(pred_html, polygons, logi_points, elapse) def transform_res( self, @@ -224,6 +242,28 @@ def re_rec_high_precise( ] return cell_box_map + @staticmethod + def get_model_path( + model_type: str, model_path: Union[str, Path, None] + ) -> Union[str, Dict[str, str]]: + if model_path is not None: + return model_path + + model_url = KEY_TO_MODEL_URL.get(model_type, None) + if isinstance(model_url, str): + model_path = DownloadModel.download(model_url) + return model_path + + if isinstance(model_url, dict): + model_paths = {} + for k, url in model_url.items(): + model_paths[k] = DownloadModel.download( + url, save_model_name=f"{model_type}_{Path(url).name}" + ) + return model_paths + + raise ValueError(f"Model URL: {type(model_url)} is not between str and dict.") + def main(): parser = argparse.ArgumentParser() diff 
--git a/wired_table_rec/table_line_rec.py b/wired_table_rec/table_structure_cycle_center_net.py similarity index 93% rename from wired_table_rec/table_line_rec.py rename to wired_table_rec/table_structure_cycle_center_net.py index b77f561..88bb4f4 100644 --- a/wired_table_rec/table_line_rec.py +++ b/wired_table_rec/table_structure_cycle_center_net.py @@ -6,8 +6,8 @@ import cv2 import numpy as np -from .utils import OrtInferSession -from .utils_table_line_rec import ( +from .utils.utils import OrtInferSession +from wired_table_rec.utils.utils_table_line_rec import ( bbox_decode, bbox_post_process, gbox_decode, @@ -16,7 +16,7 @@ group_bbox_by_gbox, nms, ) -from .utils_table_recover import ( +from wired_table_rec.utils.utils_table_recover import ( merge_adjacent_polys, sorted_ocr_boxes, box_4_2_poly_to_box_4_1, @@ -24,8 +24,8 @@ ) -class TableLineRecognition: - def __init__(self, model_path: Optional[str] = None): +class TSRCycleCenterNet: + def __init__(self, config: Dict): self.K = 1000 self.MK = 4000 self.mean = np.array([0.408, 0.447, 0.470], dtype=np.float32).reshape(1, 1, 3) @@ -34,7 +34,7 @@ def __init__(self, model_path: Optional[str] = None): self.inp_height = 1024 self.inp_width = 1024 - self.session = OrtInferSession(model_path) + self.session = OrtInferSession(config) def __call__( self, img: np.ndarray, **kwargs diff --git a/wired_table_rec/table_line_rec_plus.py b/wired_table_rec/table_structure_unet.py similarity index 96% rename from wired_table_rec/table_line_rec_plus.py rename to wired_table_rec/table_structure_unet.py index 38d282c..3eb02b6 100644 --- a/wired_table_rec/table_line_rec_plus.py +++ b/wired_table_rec/table_structure_unet.py @@ -5,22 +5,22 @@ import cv2 import numpy as np from skimage import measure -from wired_table_rec.utils import OrtInferSession, resize_img -from wired_table_rec.utils_table_line_rec import ( +from .utils.utils import OrtInferSession, resize_img +from .utils.utils_table_line_rec import ( get_table_line, final_adjust_lines, min_area_rect_box, draw_lines, adjust_lines, ) -from wired_table_rec.utils_table_recover import ( +from wired_table_rec.utils.utils_table_recover import ( sorted_ocr_boxes, box_4_2_poly_to_box_4_1, ) -class TableLineRecognitionPlus: - def __init__(self, model_path: Optional[str] = None): +class TSRUnet: + def __init__(self, config: Dict): self.K = 1000 self.MK = 4000 self.mean = np.array([123.675, 116.28, 103.53], dtype=np.float32) @@ -28,7 +28,7 @@ def __init__(self, model_path: Optional[str] = None): self.inp_height = 1024 self.inp_width = 1024 - self.session = OrtInferSession(model_path) + self.session = OrtInferSession(config) def __call__( self, img: np.ndarray, **kwargs diff --git a/wired_table_rec/utils.py b/wired_table_rec/utils.py deleted file mode 100644 index d69676c..0000000 --- a/wired_table_rec/utils.py +++ /dev/null @@ -1,397 +0,0 @@ -# -*- encoding: utf-8 -*- -import math -import traceback -from io import BytesIO -from pathlib import Path -from typing import List, Union - -import cv2 -import numpy as np -from onnxruntime import GraphOptimizationLevel, InferenceSession, SessionOptions -from PIL import Image, UnidentifiedImageError - -root_dir = Path(__file__).resolve().parent -InputType = Union[str, np.ndarray, bytes, Path] - - -class OrtInferSession: - def __init__(self, model_path: Union[str, Path], num_threads: int = -1): - self.verify_exist(model_path) - - self.num_threads = num_threads - self._init_sess_opt() - - cpu_ep = "CPUExecutionProvider" - cpu_provider_options = { - "arena_extend_strategy": 
"kSameAsRequested", - } - EP_list = [(cpu_ep, cpu_provider_options)] - try: - self.session = InferenceSession( - str(model_path), sess_options=self.sess_opt, providers=EP_list - ) - except TypeError: - # 这里兼容ort 1.5.2 - self.session = InferenceSession(str(model_path), sess_options=self.sess_opt) - - def _init_sess_opt(self): - self.sess_opt = SessionOptions() - self.sess_opt.log_severity_level = 4 - self.sess_opt.enable_cpu_mem_arena = False - - if self.num_threads != -1: - self.sess_opt.intra_op_num_threads = self.num_threads - - self.sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL - - def __call__(self, input_content: List[np.ndarray]) -> np.ndarray: - input_dict = dict(zip(self.get_input_names(), input_content)) - try: - return self.session.run(None, input_dict) - except Exception as e: - error_info = traceback.format_exc() - raise ONNXRuntimeError(error_info) from e - - def get_input_names( - self, - ): - return [v.name for v in self.session.get_inputs()] - - def get_output_name(self, output_idx=0): - return self.session.get_outputs()[output_idx].name - - def get_metadata(self): - meta_dict = self.session.get_modelmeta().custom_metadata_map - return meta_dict - - @staticmethod - def verify_exist(model_path: Union[Path, str]): - if not isinstance(model_path, Path): - model_path = Path(model_path) - - if not model_path.exists(): - raise FileNotFoundError(f"{model_path} does not exist!") - - if not model_path.is_file(): - raise FileExistsError(f"{model_path} must be a file") - - -class ONNXRuntimeError(Exception): - pass - - -class LoadImage: - def __init__( - self, - ): - pass - - def __call__(self, img: InputType) -> np.ndarray: - if not isinstance(img, InputType.__args__): - raise LoadImageError( - f"The img type {type(img)} does not in {InputType.__args__}" - ) - - img = self.load_img(img) - img = self.convert_img(img) - return img - - def load_img(self, img: InputType) -> np.ndarray: - if isinstance(img, (str, Path)): - self.verify_exist(img) - try: - img = np.array(Image.open(img)) - except UnidentifiedImageError as e: - raise LoadImageError(f"cannot identify image file {img}") from e - return img - - if isinstance(img, bytes): - img = np.array(Image.open(BytesIO(img))) - return img - - if isinstance(img, np.ndarray): - return img - - raise LoadImageError(f"{type(img)} is not supported!") - - def convert_img(self, img: np.ndarray): - if img.ndim == 2: - return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) - - if img.ndim == 3: - channel = img.shape[2] - if channel == 1: - return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) - - if channel == 2: - return self.cvt_two_to_three(img) - - if channel == 4: - return self.cvt_four_to_three(img) - - if channel == 3: - return cv2.cvtColor(img, cv2.COLOR_RGB2BGR) - - raise LoadImageError( - f"The channel({channel}) of the img is not in [1, 2, 3, 4]" - ) - - raise LoadImageError(f"The ndim({img.ndim}) of the img is not in [2, 3]") - - @staticmethod - def cvt_four_to_three(img: np.ndarray) -> np.ndarray: - """RGBA → BGR""" - r, g, b, a = cv2.split(img) - new_img = cv2.merge((b, g, r)) - - not_a = cv2.bitwise_not(a) - not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR) - - new_img = cv2.bitwise_and(new_img, new_img, mask=a) - new_img = cv2.add(new_img, not_a) - return new_img - - @staticmethod - def cvt_two_to_three(img: np.ndarray) -> np.ndarray: - """gray + alpha → BGR""" - img_gray = img[..., 0] - img_bgr = cv2.cvtColor(img_gray, cv2.COLOR_GRAY2BGR) - - img_alpha = img[..., 1] - not_a = cv2.bitwise_not(img_alpha) - not_a = 
cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR) - - new_img = cv2.bitwise_and(img_bgr, img_bgr, mask=img_alpha) - new_img = cv2.add(new_img, not_a) - return new_img - - @staticmethod - def verify_exist(file_path: Union[str, Path]): - if not Path(file_path).exists(): - raise LoadImageError(f"{file_path} does not exist.") - - -class LoadImageError(Exception): - pass - - -# Pillow >=v9.1.0 use a slightly different naming scheme for filters. -# Set pillow_interp_codes according to the naming scheme used. -if Image is not None: - if hasattr(Image, "Resampling"): - pillow_interp_codes = { - "nearest": Image.Resampling.NEAREST, - "bilinear": Image.Resampling.BILINEAR, - "bicubic": Image.Resampling.BICUBIC, - "box": Image.Resampling.BOX, - "lanczos": Image.Resampling.LANCZOS, - "hamming": Image.Resampling.HAMMING, - } - else: - pillow_interp_codes = { - "nearest": Image.NEAREST, - "bilinear": Image.BILINEAR, - "bicubic": Image.BICUBIC, - "box": Image.BOX, - "lanczos": Image.LANCZOS, - "hamming": Image.HAMMING, - } - -cv2_interp_codes = { - "nearest": cv2.INTER_NEAREST, - "bilinear": cv2.INTER_LINEAR, - "bicubic": cv2.INTER_CUBIC, - "area": cv2.INTER_AREA, - "lanczos": cv2.INTER_LANCZOS4, -} - - -def resize_img(img, scale, keep_ratio=True): - if keep_ratio: - # 缩小使用area更保真 - if min(img.shape[:2]) > min(scale): - interpolation = "area" - else: - interpolation = "bicubic" # bilinear - img_new, scale_factor = imrescale( - img, scale, return_scale=True, interpolation=interpolation - ) - # the w_scale and h_scale has minor difference - # a real fix should be done in the mmcv.imrescale in the future - new_h, new_w = img_new.shape[:2] - h, w = img.shape[:2] - w_scale = new_w / w - h_scale = new_h / h - else: - img_new, w_scale, h_scale = imresize(img, scale, return_scale=True) - return img_new, w_scale, h_scale - - -def imrescale(img, scale, return_scale=False, interpolation="bilinear", backend=None): - """Resize image while keeping the aspect ratio. - - Args: - img (ndarray): The input image. - scale (float | tuple[int]): The scaling factor or maximum size. - If it is a float number, then the image will be rescaled by this - factor, else if it is a tuple of 2 integers, then the image will - be rescaled as large as possible within the scale. - return_scale (bool): Whether to return the scaling factor besides the - rescaled image. - interpolation (str): Same as :func:`resize`. - backend (str | None): Same as :func:`resize`. - - Returns: - ndarray: The rescaled image. - """ - h, w = img.shape[:2] - new_size, scale_factor = rescale_size((w, h), scale, return_scale=True) - rescaled_img = imresize(img, new_size, interpolation=interpolation, backend=backend) - if return_scale: - return rescaled_img, scale_factor - else: - return rescaled_img - - -def imresize( - img, size, return_scale=False, interpolation="bilinear", out=None, backend=None -): - """Resize image to a given size. - - Args: - img (ndarray): The input image. - size (tuple[int]): Target size (w, h). - return_scale (bool): Whether to return `w_scale` and `h_scale`. - interpolation (str): Interpolation method, accepted values are - "nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2' - backend, "nearest", "bilinear" for 'pillow' backend. - out (ndarray): The output destination. - backend (str | None): The image resize backend type. Options are `cv2`, - `pillow`, `None`. If backend is None, the global imread_backend - specified by ``mmcv.use_backend()`` will be used. Default: None. 
- - Returns: - tuple | ndarray: (`resized_img`, `w_scale`, `h_scale`) or - `resized_img`. - """ - h, w = img.shape[:2] - if backend is None: - backend = "cv2" - if backend not in ["cv2", "pillow"]: - raise ValueError( - f"backend: {backend} is not supported for resize." - f"Supported backends are 'cv2', 'pillow'" - ) - - if backend == "pillow": - assert img.dtype == np.uint8, "Pillow backend only support uint8 type" - pil_image = Image.fromarray(img) - pil_image = pil_image.resize(size, pillow_interp_codes[interpolation]) - resized_img = np.array(pil_image) - else: - resized_img = cv2.resize( - img, size, dst=out, interpolation=cv2_interp_codes[interpolation] - ) - if not return_scale: - return resized_img - else: - w_scale = size[0] / w - h_scale = size[1] / h - return resized_img, w_scale, h_scale - - -def rescale_size(old_size, scale, return_scale=False): - """Calculate the new size to be rescaled to. - - Args: - old_size (tuple[int]): The old size (w, h) of image. - scale (float | tuple[int]): The scaling factor or maximum size. - If it is a float number, then the image will be rescaled by this - factor, else if it is a tuple of 2 integers, then the image will - be rescaled as large as possible within the scale. - return_scale (bool): Whether to return the scaling factor besides the - rescaled image size. - - Returns: - tuple[int]: The new rescaled image size. - """ - w, h = old_size - if isinstance(scale, (float, int)): - if scale <= 0: - raise ValueError(f"Invalid scale {scale}, must be positive.") - scale_factor = scale - elif isinstance(scale, tuple): - max_long_edge = max(scale) - max_short_edge = min(scale) - scale_factor = min(max_long_edge / max(h, w), max_short_edge / min(h, w)) - else: - raise TypeError( - f"Scale must be a number or tuple of int, but got {type(scale)}" - ) - - new_size = _scale_size((w, h), scale_factor) - - if return_scale: - return new_size, scale_factor - else: - return new_size - - -def _scale_size(size, scale): - """Rescale a size by a ratio. - - Args: - size (tuple[int]): (w, h). - scale (float | tuple(float)): Scaling factor. - - Returns: - tuple[int]: scaled size. 
- """ - if isinstance(scale, (float, int)): - scale = (scale, scale) - w, h = size - return int(w * float(scale[0]) + 0.5), int(h * float(scale[1]) + 0.5) - - -class ImageOrientationCorrector: - """ - 对图片小角度(-90 - + 90度进行修正) - """ - - def __init__(self): - self.img_loader = LoadImage() - - def __call__(self, img: InputType): - img = self.img_loader(img) - # 取灰度 - gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) - # 二值化 - gray = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)[1] - # 边缘检测 - edges = cv2.Canny(gray, 100, 250, apertureSize=3) - # 霍夫变换,摘自https://blog.csdn.net/feilong_csdn/article/details/81586322 - lines = cv2.HoughLines(edges, 1, np.pi / 180, 0) - for rho, theta in lines[0]: - a = np.cos(theta) - b = np.sin(theta) - x0 = a * rho - y0 = b * rho - x1 = int(x0 + 1000 * (-b)) - y1 = int(y0 + 1000 * (a)) - x2 = int(x0 - 1000 * (-b)) - y2 = int(y0 - 1000 * (a)) - if x1 == x2 or y1 == y2: - return img - else: - t = float(y2 - y1) / (x2 - x1) - # 得到角度后 - rotate_angle = math.degrees(math.atan(t)) - if rotate_angle > 45: - rotate_angle = -90 + rotate_angle - elif rotate_angle < -45: - rotate_angle = 90 + rotate_angle - # 旋转图像 - (h, w) = img.shape[:2] - center = (w // 2, h // 2) - M = cv2.getRotationMatrix2D(center, rotate_angle, 1.0) - return cv2.warpAffine(img, M, (w, h)) diff --git a/wired_table_rec/utils/__init__.py b/wired_table_rec/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/wired_table_rec/utils/download_model.py b/wired_table_rec/utils/download_model.py new file mode 100644 index 0000000..adedb5d --- /dev/null +++ b/wired_table_rec/utils/download_model.py @@ -0,0 +1,67 @@ +import io +from pathlib import Path +from typing import Optional, Union + +import requests +from tqdm import tqdm + +from .logger import get_logger + +logger = get_logger("DownloadModel") + +PROJECT_DIR = Path(__file__).resolve().parent.parent +DEFAULT_MODEL_DIR = PROJECT_DIR / "models" + + +class DownloadModel: + @classmethod + def download( + cls, + model_full_url: Union[str, Path], + save_dir: Union[str, Path, None] = None, + save_model_name: Optional[str] = None, + ) -> str: + if save_dir is None: + save_dir = DEFAULT_MODEL_DIR + + save_dir.mkdir(parents=True, exist_ok=True) + + if save_model_name is None: + save_model_name = Path(model_full_url).name + + save_file_path = save_dir / save_model_name + if save_file_path.exists(): + logger.debug("%s already exists", save_file_path) + return str(save_file_path) + + try: + logger.info("Download %s to %s", model_full_url, save_dir) + file = cls.download_as_bytes_with_progress(model_full_url, save_model_name) + cls.save_file(save_file_path, file) + except Exception as exc: + raise DownloadModelError from exc + return str(save_file_path) + + @staticmethod + def download_as_bytes_with_progress( + url: Union[str, Path], name: Optional[str] = None + ) -> bytes: + resp = requests.get(str(url), stream=True, allow_redirects=True, timeout=180) + total = int(resp.headers.get("content-length", 0)) + bio = io.BytesIO() + with tqdm( + desc=name, total=total, unit="b", unit_scale=True, unit_divisor=1024 + ) as pbar: + for chunk in resp.iter_content(chunk_size=65536): + pbar.update(len(chunk)) + bio.write(chunk) + return bio.getvalue() + + @staticmethod + def save_file(save_path: Union[str, Path], file: bytes): + with open(save_path, "wb") as f: + f.write(file) + + +class DownloadModelError(Exception): + pass diff --git a/wired_table_rec/utils/logger.py b/wired_table_rec/utils/logger.py new file mode 100644 index 0000000..2950987 --- 
/dev/null +++ b/wired_table_rec/utils/logger.py @@ -0,0 +1,21 @@ +# -*- encoding: utf-8 -*- +# @Author: Jocker1212 +# @Contact: xinyijianggo@gmail.com +import logging +from functools import lru_cache + + +@lru_cache(maxsize=32) +def get_logger(name: str) -> logging.Logger: + logger = logging.getLogger(name) + logger.setLevel(logging.DEBUG) + + fmt = "%(asctime)s - %(name)s - %(levelname)s: %(message)s" + format_str = logging.Formatter(fmt) + + sh = logging.StreamHandler() + sh.setLevel(logging.DEBUG) + + logger.addHandler(sh) + sh.setFormatter(format_str) + return logger diff --git a/wired_table_rec/utils/utils.py b/wired_table_rec/utils/utils.py new file mode 100644 index 0000000..cdc1f44 --- /dev/null +++ b/wired_table_rec/utils/utils.py @@ -0,0 +1,694 @@ +# -*- encoding: utf-8 -*- +import math +import os +import platform +import traceback +from enum import Enum +from io import BytesIO +from pathlib import Path +from typing import List, Union, Dict, Any, Tuple, Optional + +import cv2 +import numpy as np +from onnxruntime import ( + GraphOptimizationLevel, + InferenceSession, + SessionOptions, + get_available_providers, + get_device, +) +from PIL import Image, UnidentifiedImageError + +from wired_table_rec.utils.logger import get_logger + +root_dir = Path(__file__).resolve().parent +InputType = Union[str, np.ndarray, bytes, Path] + + +class EP(Enum): + CPU_EP = "CPUExecutionProvider" + CUDA_EP = "CUDAExecutionProvider" + DIRECTML_EP = "DmlExecutionProvider" + + +class OrtInferSession: + def __init__(self, config: Dict[str, Any]): + self.logger = get_logger("OrtInferSession") + + model_path = config.get("model_path", None) + self._verify_model(model_path) + + self.cfg_use_cuda = config.get("use_cuda", None) + self.cfg_use_dml = config.get("use_dml", None) + + self.had_providers: List[str] = get_available_providers() + EP_list = self._get_ep_list() + + sess_opt = self._init_sess_opts(config) + self.session = InferenceSession( + model_path, + sess_options=sess_opt, + providers=EP_list, + ) + self._verify_providers() + + @staticmethod + def _init_sess_opts(config: Dict[str, Any]) -> SessionOptions: + sess_opt = SessionOptions() + sess_opt.log_severity_level = 4 + sess_opt.enable_cpu_mem_arena = False + sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL + + cpu_nums = os.cpu_count() + intra_op_num_threads = config.get("intra_op_num_threads", -1) + if intra_op_num_threads != -1 and 1 <= intra_op_num_threads <= cpu_nums: + sess_opt.intra_op_num_threads = intra_op_num_threads + + inter_op_num_threads = config.get("inter_op_num_threads", -1) + if inter_op_num_threads != -1 and 1 <= inter_op_num_threads <= cpu_nums: + sess_opt.inter_op_num_threads = inter_op_num_threads + + return sess_opt + + def get_metadata(self, key: str = "character") -> list: + meta_dict = self.session.get_modelmeta().custom_metadata_map + content_list = meta_dict[key].splitlines() + return content_list + + def _get_ep_list(self) -> List[Tuple[str, Dict[str, Any]]]: + cpu_provider_opts = { + "arena_extend_strategy": "kSameAsRequested", + } + EP_list = [(EP.CPU_EP.value, cpu_provider_opts)] + + cuda_provider_opts = { + "device_id": 0, + "arena_extend_strategy": "kNextPowerOfTwo", + "cudnn_conv_algo_search": "EXHAUSTIVE", + "do_copy_in_default_stream": True, + } + self.use_cuda = self._check_cuda() + if self.use_cuda: + EP_list.insert(0, (EP.CUDA_EP.value, cuda_provider_opts)) + + self.use_directml = self._check_dml() + if self.use_directml: + self.logger.info( + "Windows 10 or above detected, try to use 
DirectML as the primary provider"
+            )
+            directml_options = (
+                cuda_provider_opts if self.use_cuda else cpu_provider_opts
+            )
+            EP_list.insert(0, (EP.DIRECTML_EP.value, directml_options))
+        return EP_list
+
+    def _check_cuda(self) -> bool:
+        if not self.cfg_use_cuda:
+            return False
+
+        cur_device = get_device()
+        if cur_device == "GPU" and EP.CUDA_EP.value in self.had_providers:
+            return True
+
+        self.logger.warning(
+            "%s is not in available providers (%s). Use %s inference by default.",
+            EP.CUDA_EP.value,
+            self.had_providers,
+            self.had_providers[0],
+        )
+        self.logger.info("!!!Recommend using rapidocr_paddle for inference on GPU.")
+        self.logger.info(
+            "(For reference only) If you want to use GPU acceleration, you must do the following:"
+        )
+        self.logger.info(
+            "First, uninstall all onnxruntime packages in the current environment."
+        )
+        self.logger.info(
+            "Second, install onnxruntime-gpu by `pip install onnxruntime-gpu`."
+        )
+        self.logger.info(
+            "\tNote the onnxruntime-gpu version must match your cuda and cudnn version."
+        )
+        self.logger.info(
+            "\tYou can refer to this link: https://onnxruntime.ai/docs/execution-providers/CUDA-EP.html"
+        )
+        self.logger.info(
+            "Third, ensure %s is in available providers list. e.g. ['CUDAExecutionProvider', 'CPUExecutionProvider']",
+            EP.CUDA_EP.value,
+        )
+        return False
+
+    def _check_dml(self) -> bool:
+        if not self.cfg_use_dml:
+            return False
+
+        cur_os = platform.system()
+        if cur_os != "Windows":
+            self.logger.warning(
+                "DirectML is only supported on Windows. The current OS is %s. Use %s inference by default.",
+                cur_os,
+                self.had_providers[0],
+            )
+            return False
+
+        cur_window_version = int(platform.release().split(".")[0])
+        if cur_window_version < 10:
+            self.logger.warning(
+                "DirectML is only supported on Windows 10 and above. The current Windows version is %s. Use %s inference by default.",
+                cur_window_version,
+                self.had_providers[0],
+            )
+            return False
+
+        if EP.DIRECTML_EP.value in self.had_providers:
+            return True
+
+        self.logger.warning(
+            "%s is not in available providers (%s). Use %s inference by default.",
+            EP.DIRECTML_EP.value,
+            self.had_providers,
+            self.had_providers[0],
+        )
+        self.logger.info("If you want to use DirectML acceleration, you must do the following:")
+        self.logger.info(
+            "First, uninstall all onnxruntime packages in the current environment."
+        )
+        self.logger.info(
+            "Second, install onnxruntime-directml by `pip install onnxruntime-directml`"
+        )
+        self.logger.info(
+            "Third, ensure %s is in available providers list. e.g. ['DmlExecutionProvider', 'CPUExecutionProvider']",
+            EP.DIRECTML_EP.value,
+        )
+        return False
+
+    def _verify_providers(self):
+        session_providers = self.session.get_providers()
+        first_provider = session_providers[0]
+
+        if self.use_cuda and first_provider != EP.CUDA_EP.value:
+            self.logger.warning(
+                "%s is not available for current env, the inference part is automatically shifted to be executed under %s.",
+                EP.CUDA_EP.value,
+                first_provider,
+            )
+
+        if self.use_directml and first_provider != EP.DIRECTML_EP.value:
+            self.logger.warning(
+                "%s is not available for current env, the inference part is automatically shifted to be executed under %s.",
+                EP.DIRECTML_EP.value,
+                first_provider,
+            )
+
+    def __call__(self, input_content: List[np.ndarray]) -> np.ndarray:
+        input_dict = dict(zip(self.get_input_names(), input_content))
+        try:
+            return self.session.run(None, input_dict)
+        except Exception as e:
+            error_info = traceback.format_exc()
+            raise ONNXRuntimeError(error_info) from e
+
+    def get_input_names(self) -> List[str]:
+        return [v.name for v in self.session.get_inputs()]
+
+    def get_output_names(self) -> List[str]:
+        return [v.name for v in self.session.get_outputs()]
+
+    def get_character_list(self, key: str = "character") -> List[str]:
+        meta_dict = self.session.get_modelmeta().custom_metadata_map
+        return meta_dict[key].splitlines()
+
+    def have_key(self, key: str = "character") -> bool:
+        meta_dict = self.session.get_modelmeta().custom_metadata_map
+        if key in meta_dict.keys():
+            return True
+        return False
+
+    @staticmethod
+    def _verify_model(model_path: Union[str, Path, None]):
+        if model_path is None:
+            raise ValueError("model_path is None!")
+
+        model_path = Path(model_path)
+        if not model_path.exists():
+            raise FileNotFoundError(f"{model_path} does not exist.")
+
+        if not model_path.is_file():
+            raise FileExistsError(f"{model_path} is not a file.")
+
+
+class ONNXRuntimeError(Exception):
+    pass
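+
+
+# Usage sketch (illustrative; "table.jpg" is a placeholder path):
+#   load_img = LoadImage()
+#   img = load_img("table.jpg")  # str/Path/bytes/np.ndarray inputs are normalized to BGR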
+
+
+class LoadImage:
+    def __init__(
+        self,
+    ):
+        pass
+
+    def __call__(self, img: InputType) -> np.ndarray:
+        if not isinstance(img, InputType.__args__):
+            raise LoadImageError(
+                f"The img type {type(img)} is not in {InputType.__args__}"
+            )
+
+        img = self.load_img(img)
+        img = self.convert_img(img)
+        return img
+
+    def load_img(self, img: InputType) -> np.ndarray:
+        if isinstance(img, (str, Path)):
+            self.verify_exist(img)
+            try:
+                img = np.array(Image.open(img))
+            except UnidentifiedImageError as e:
+                raise LoadImageError(f"cannot identify image file {img}") from e
+            return img
+
+        if isinstance(img, bytes):
+            img = np.array(Image.open(BytesIO(img)))
+            return img
+
+        if isinstance(img, np.ndarray):
+            return img
+
+        raise LoadImageError(f"{type(img)} is not supported!")
+
+    def convert_img(self, img: np.ndarray):
+        if img.ndim == 2:
+            return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+
+        if img.ndim == 3:
+            channel = img.shape[2]
+            if channel == 1:
+                return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+
+            if channel == 2:
+                return self.cvt_two_to_three(img)
+
+            if channel == 4:
+                return self.cvt_four_to_three(img)
+
+            if channel == 3:
+                return cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
+
+            raise LoadImageError(
+                f"The channel({channel}) of the img is not in [1, 2, 3, 4]"
+            )
+
+        raise LoadImageError(f"The ndim({img.ndim}) of the img is not in [2, 3]")
+
+    @staticmethod
+    def cvt_four_to_three(img: np.ndarray) -> np.ndarray:
+        """RGBA → BGR"""
+        r, g, b, a = cv2.split(img)
+        new_img = cv2.merge((b, g, r))
+
+        not_a = cv2.bitwise_not(a)
+        not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR)
+
+        new_img = cv2.bitwise_and(new_img, new_img, mask=a)
+        new_img = cv2.add(new_img, not_a)
+        return new_img
+
+    @staticmethod
+    def cvt_two_to_three(img: np.ndarray) -> np.ndarray:
+        """gray + alpha → BGR"""
+        img_gray = img[..., 0]
+        img_bgr = cv2.cvtColor(img_gray, cv2.COLOR_GRAY2BGR)
+
+        img_alpha = img[..., 1]
+        not_a = cv2.bitwise_not(img_alpha)
+        not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR)
+
+        new_img = cv2.bitwise_and(img_bgr, img_bgr, mask=img_alpha)
+        new_img = cv2.add(new_img, not_a)
+        return new_img
+
+    @staticmethod
+    def verify_exist(file_path: Union[str, Path]):
+        if not Path(file_path).exists():
+            raise LoadImageError(f"{file_path} does not exist.")
+
+
+class LoadImageError(Exception):
+    pass
+
+
+# Pillow >=v9.1.0 uses a slightly different naming scheme for filters.
+# Set pillow_interp_codes according to the naming scheme used.
+if Image is not None:
+    if hasattr(Image, "Resampling"):
+        pillow_interp_codes = {
+            "nearest": Image.Resampling.NEAREST,
+            "bilinear": Image.Resampling.BILINEAR,
+            "bicubic": Image.Resampling.BICUBIC,
+            "box": Image.Resampling.BOX,
+            "lanczos": Image.Resampling.LANCZOS,
+            "hamming": Image.Resampling.HAMMING,
+        }
+    else:
+        pillow_interp_codes = {
+            "nearest": Image.NEAREST,
+            "bilinear": Image.BILINEAR,
+            "bicubic": Image.BICUBIC,
+            "box": Image.BOX,
+            "lanczos": Image.LANCZOS,
+            "hamming": Image.HAMMING,
+        }
+
+cv2_interp_codes = {
+    "nearest": cv2.INTER_NEAREST,
+    "bilinear": cv2.INTER_LINEAR,
+    "bicubic": cv2.INTER_CUBIC,
+    "area": cv2.INTER_AREA,
+    "lanczos": cv2.INTER_LANCZOS4,
+}
+
+
+def resize_img(img, scale, keep_ratio=True):
+    if keep_ratio:
+        # area interpolation preserves detail better when downscaling
+        if min(img.shape[:2]) > min(scale):
+            interpolation = "area"
+        else:
+            interpolation = "bicubic"  # bilinear
+        img_new, scale_factor = imrescale(
+            img, scale, return_scale=True, interpolation=interpolation
+        )
+        # the w_scale and h_scale have a minor difference
+        # a real fix should be done in mmcv.imrescale in the future
+        new_h, new_w = img_new.shape[:2]
+        h, w = img.shape[:2]
+        w_scale = new_w / w
+        h_scale = new_h / h
+    else:
+        img_new, w_scale, h_scale = imresize(img, scale, return_scale=True)
+    return img_new, w_scale, h_scale
+
+
+def imrescale(img, scale, return_scale=False, interpolation="bilinear", backend=None):
+    """Resize image while keeping the aspect ratio.
+
+    Args:
+        img (ndarray): The input image.
+        scale (float | tuple[int]): The scaling factor or maximum size.
+            If it is a float number, then the image will be rescaled by this
+            factor, else if it is a tuple of 2 integers, then the image will
+            be rescaled as large as possible within the scale.
+        return_scale (bool): Whether to return the scaling factor besides the
+            rescaled image.
+        interpolation (str): Same as :func:`resize`.
+        backend (str | None): Same as :func:`resize`.
+
+    Returns:
+        ndarray: The rescaled image.
+    """
+    h, w = img.shape[:2]
+    new_size, scale_factor = rescale_size((w, h), scale, return_scale=True)
+    rescaled_img = imresize(img, new_size, interpolation=interpolation, backend=backend)
+    if return_scale:
+        return rescaled_img, scale_factor
+    else:
+        return rescaled_img
+
+
+def imresize(
+    img, size, return_scale=False, interpolation="bilinear", out=None, backend=None
+):
+    """Resize image to a given size.
+
+    Args:
+        img (ndarray): The input image.
+        size (tuple[int]): Target size (w, h).
+        return_scale (bool): Whether to return `w_scale` and `h_scale`.
+        interpolation (str): Interpolation method, accepted values are
+            "nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2'
+            backend, "nearest", "bilinear" for 'pillow' backend.
+        out (ndarray): The output destination.
+        backend (str | None): The image resize backend type. Options are `cv2`,
+            `pillow`, `None`. If backend is None, the global imread_backend
+            specified by ``mmcv.use_backend()`` will be used. Default: None.
+
+    Returns:
+        tuple | ndarray: (`resized_img`, `w_scale`, `h_scale`) or
+        `resized_img`.
+    """
+    h, w = img.shape[:2]
+    if backend is None:
+        backend = "cv2"
+    if backend not in ["cv2", "pillow"]:
+        raise ValueError(
+            f"backend: {backend} is not supported for resize. "
+            f"Supported backends are 'cv2', 'pillow'"
+        )
+
+    if backend == "pillow":
+        assert img.dtype == np.uint8, "Pillow backend only supports uint8 type"
+        pil_image = Image.fromarray(img)
+        pil_image = pil_image.resize(size, pillow_interp_codes[interpolation])
+        resized_img = np.array(pil_image)
+    else:
+        resized_img = cv2.resize(
+            img, size, dst=out, interpolation=cv2_interp_codes[interpolation]
+        )
+    if not return_scale:
+        return resized_img
+    else:
+        w_scale = size[0] / w
+        h_scale = size[1] / h
+        return resized_img, w_scale, h_scale
+
+
+def rescale_size(old_size, scale, return_scale=False):
+    """Calculate the new size to be rescaled to.
+
+    Args:
+        old_size (tuple[int]): The old size (w, h) of image.
+        scale (float | tuple[int]): The scaling factor or maximum size.
+            If it is a float number, then the image will be rescaled by this
+            factor, else if it is a tuple of 2 integers, then the image will
+            be rescaled as large as possible within the scale.
+        return_scale (bool): Whether to return the scaling factor besides the
+            rescaled image size.
+
+    Returns:
+        tuple[int]: The new rescaled image size.
+    """
+    w, h = old_size
+    if isinstance(scale, (float, int)):
+        if scale <= 0:
+            raise ValueError(f"Invalid scale {scale}, must be positive.")
+        scale_factor = scale
+    elif isinstance(scale, tuple):
+        max_long_edge = max(scale)
+        max_short_edge = min(scale)
+        scale_factor = min(max_long_edge / max(h, w), max_short_edge / min(h, w))
+    else:
+        raise TypeError(
+            f"Scale must be a number or tuple of int, but got {type(scale)}"
+        )
+
+    new_size = _scale_size((w, h), scale_factor)
+
+    if return_scale:
+        return new_size, scale_factor
+    else:
+        return new_size
+
+
+def _scale_size(size, scale):
+    """Rescale a size by a ratio.
+
+    Args:
+        size (tuple[int]): (w, h).
+        scale (float | tuple(float)): Scaling factor.
+
+    Returns:
+        tuple[int]: scaled size.
+    """
+    if isinstance(scale, (float, int)):
+        scale = (scale, scale)
+    w, h = size
+    return int(w * float(scale[0]) + 0.5), int(h * float(scale[1]) + 0.5)
+
+
+class ImageOrientationCorrector:
+    """
+    Corrects small rotations of an image (between -90 and +90 degrees).
+    """
+
+    def __init__(self):
+        self.img_loader = LoadImage()
+
+    def __call__(self, img: InputType):
+        img = self.img_loader(img)
+        # convert to grayscale
+        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+        # binarize
+        gray = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)[1]
+        # edge detection
+        edges = cv2.Canny(gray, 100, 250, apertureSize=3)
+        # Hough transform, adapted from https://blog.csdn.net/feilong_csdn/article/details/81586322
+        lines = cv2.HoughLines(edges, 1, np.pi / 180, 0)
+        for rho, theta in lines[0]:
+            a = np.cos(theta)
+            b = np.sin(theta)
+            x0 = a * rho
+            y0 = b * rho
+            x1 = int(x0 + 1000 * (-b))
+            y1 = int(y0 + 1000 * (a))
+            x2 = int(x0 - 1000 * (-b))
+            y2 = int(y0 - 1000 * (a))
+            if x1 == x2 or y1 == y2:
+                return img
+            else:
+                t = float(y2 - y1) / (x2 - x1)
+                # after obtaining the angle
+                rotate_angle = math.degrees(math.atan(t))
+                if rotate_angle > 45:
+                    rotate_angle = -90 + rotate_angle
+                elif rotate_angle < -45:
+                    rotate_angle = 90 + rotate_angle
+                # rotate the image
+                (h, w) = img.shape[:2]
+                center = (w // 2, h // 2)
+                M = cv2.getRotationMatrix2D(center, rotate_angle, 1.0)
+                return cv2.warpAffine(img, M, (w, h))
+
+
+class VisTable:
+    def __init__(self):
+        self.load_img = LoadImage()
+
+    def __call__(
+        self,
+        img_path: Union[str, Path],
+        table_results,
+        save_html_path: Optional[Union[str, Path]] = None,
+        save_drawed_path: Optional[Union[str, Path]] = None,
+        save_logic_path: Optional[Union[str, Path]] = None,
+    ):
+        if save_html_path:
+            html_with_border = self.insert_border_style(table_results.pred_html)
+            self.save_html(save_html_path, html_with_border)
+
+        table_cell_bboxes = table_results.cell_bboxes
+        table_logic_points = table_results.logic_points
+        if table_cell_bboxes is None:
+            return None
+
+        img = self.load_img(img_path)
+
+        dims_bboxes = table_cell_bboxes.shape[1]
+        if dims_bboxes == 4:
+            drawed_img = self.draw_rectangle(img, table_cell_bboxes)
+        elif dims_bboxes == 8:
+            drawed_img = self.draw_polylines(img, table_cell_bboxes)
+        else:
+            raise ValueError("The shape of table bounding boxes must be 4 or 8.")
+
+        if save_drawed_path:
+            self.save_img(save_drawed_path, drawed_img)
+
+        if save_logic_path:
+            polygons = [[box[0], box[1], box[4], box[5]] for box in table_cell_bboxes]
+            self.plot_rec_box_with_logic_info(
+                img_path, save_logic_path, table_logic_points, polygons
+            )
+        return drawed_img
+
+    def insert_border_style(self, table_html_str: str):
+        style_res = """<meta charset="UTF-8"><style>
+        table {
+            border-collapse: collapse;
+            width: 100%;
+        }
+        th, td {
+            border: 1px solid black;
+            padding: 8px;
+            text-align: center;
+        }
+        td {
+            min-width: 50px;
+        }
+        </style>"""
+
+        prefix_table, suffix_table = table_html_str.split("<body>")
+        html_with_border = f"{prefix_table}<body>{style_res}{suffix_table}"
+        return html_with_border
+
+    def plot_rec_box_with_logic_info(
+        self, img_path, output_path, logic_points, sorted_polygons
+    ):
+        """
+        :param img_path
+        :param output_path
+        :param logic_points: [row_start,row_end,col_start,col_end]
+        :param sorted_polygons: [xmin,ymin,xmax,ymax]
+        :return:
+        """
+        # read the original image
+        img = cv2.imread(img_path)
+        img = cv2.copyMakeBorder(
+            img, 0, 0, 0, 100, cv2.BORDER_CONSTANT, value=[255, 255, 255]
+        )
+        # draw the polygons as rectangles
+        for idx, polygon in enumerate(sorted_polygons):
+            x0, y0, x1, y1 = polygon[0], polygon[1], polygon[2], polygon[3]
+            x0 = round(x0)
+            y0 = round(y0)
+            x1 = round(x1)
+            y1 = round(y1)
+            cv2.rectangle(img, (x0, y0), (x1, y1), (0, 0, 255), 1)
+            # enlarge the font size and line width
+            font_scale = 0.9  # previously 0.5
+            thickness = 1  # previously 1
+            logic_point = logic_points[idx]
+            cv2.putText(
+                img,
+                
f"row: {logic_point[0]}-{logic_point[1]}", + (x0 + 3, y0 + 8), + cv2.FONT_HERSHEY_PLAIN, + font_scale, + (0, 0, 255), + thickness, + ) + cv2.putText( + img, + f"col: {logic_point[2]}-{logic_point[3]}", + (x0 + 3, y0 + 18), + cv2.FONT_HERSHEY_PLAIN, + font_scale, + (0, 0, 255), + thickness, + ) + os.makedirs(os.path.dirname(output_path), exist_ok=True) + # 保存绘制后的图像 + self.save_img(output_path, img) + + @staticmethod + def draw_rectangle(img: np.ndarray, boxes: np.ndarray) -> np.ndarray: + img_copy = img.copy() + for box in boxes.astype(int): + x1, y1, x2, y2 = box + cv2.rectangle(img_copy, (x1, y1), (x2, y2), (255, 0, 0), 2) + return img_copy + + @staticmethod + def draw_polylines(img: np.ndarray, points) -> np.ndarray: + img_copy = img.copy() + for point in points.astype(int): + point = point.reshape(4, 2) + cv2.polylines(img_copy, [point.astype(int)], True, (255, 0, 0), 2) + return img_copy + + @staticmethod + def save_img(save_path: Union[str, Path], img: np.ndarray): + cv2.imwrite(str(save_path), img) + + @staticmethod + def save_html(save_path: Union[str, Path], html: str): + with open(save_path, "w", encoding="utf-8") as f: + f.write(html) diff --git a/wired_table_rec/utils_table_line_rec.py b/wired_table_rec/utils/utils_table_line_rec.py similarity index 100% rename from wired_table_rec/utils_table_line_rec.py rename to wired_table_rec/utils/utils_table_line_rec.py diff --git a/wired_table_rec/utils_table_recover.py b/wired_table_rec/utils/utils_table_recover.py similarity index 87% rename from wired_table_rec/utils_table_recover.py rename to wired_table_rec/utils/utils_table_recover.py index 235c39e..2726e57 100644 --- a/wired_table_rec/utils_table_recover.py +++ b/wired_table_rec/utils/utils_table_recover.py @@ -1,7 +1,6 @@ # -*- encoding: utf-8 -*- # @Author: SWHL # @Contact: liekkaskono@163.com -import os import random from typing import Any, Dict, List, Union, Set, Tuple @@ -240,54 +239,6 @@ def sorted_ocr_boxes( return _boxes, indices -def plot_rec_box_with_logic_info(img_path, output_path, logic_points, sorted_polygons): - """ - :param img_path - :param output_path - :param logic_points: [row_start,row_end,col_start,col_end] - :param sorted_polygons: [xmin,ymin,xmax,ymax] - :return: - """ - # 读取原图 - img = cv2.imread(img_path) - img = cv2.copyMakeBorder( - img, 0, 0, 0, 100, cv2.BORDER_CONSTANT, value=[255, 255, 255] - ) - # 绘制 polygons 矩形 - for idx, polygon in enumerate(sorted_polygons): - x0, y0, x1, y1 = polygon[0], polygon[1], polygon[2], polygon[3] - x0 = round(x0) - y0 = round(y0) - x1 = round(x1) - y1 = round(y1) - cv2.rectangle(img, (x0, y0), (x1, y1), (0, 0, 255), 1) - # 增大字体大小和线宽 - font_scale = 0.9 # 原先是0.5 - thickness = 1 # 原先是1 - logic_point = logic_points[idx] - cv2.putText( - img, - f"row: {logic_point[0]}-{logic_point[1]}", - (x0 + 3, y0 + 8), - cv2.FONT_HERSHEY_PLAIN, - font_scale, - (0, 0, 255), - thickness, - ) - cv2.putText( - img, - f"col: {logic_point[2]}-{logic_point[3]}", - (x0 + 3, y0 + 18), - cv2.FONT_HERSHEY_PLAIN, - font_scale, - (0, 0, 255), - thickness, - ) - os.makedirs(os.path.dirname(output_path), exist_ok=True) - # 保存绘制后的图像 - cv2.imwrite(output_path, img) - - def trans_char_ocr_res(ocr_res): word_result = [] for res in ocr_res: @@ -301,44 +252,6 @@ def trans_char_ocr_res(ocr_res): return word_result -def plot_rec_box(img_path, output_path, sorted_polygons): - """ - :param img_path - :param output_path - :param sorted_polygons: [xmin,ymin,xmax,ymax] - :return: - """ - # 处理ocr_res - img = cv2.imread(img_path) - img = cv2.copyMakeBorder( - img, 0, 
0, 0, 100, cv2.BORDER_CONSTANT, value=[255, 255, 255] - ) - # 绘制 ocr_res 矩形 - for idx, polygon in enumerate(sorted_polygons): - x0, y0, x1, y1 = polygon[0], polygon[1], polygon[2], polygon[3] - x0 = round(x0) - y0 = round(y0) - x1 = round(x1) - y1 = round(y1) - cv2.rectangle(img, (x0, y0), (x1, y1), (0, 0, 255), 1) - # 增大字体大小和线宽 - font_scale = 0.9 # 原先是0.5 - thickness = 1 # 原先是1 - - cv2.putText( - img, - str(idx), - (x0 + 5, y0 + 5), - cv2.FONT_HERSHEY_PLAIN, - font_scale, - (0, 0, 255), - thickness, - ) - os.makedirs(os.path.dirname(output_path), exist_ok=True) - # 保存绘制后的图像 - cv2.imwrite(output_path, img) - - def box_4_1_poly_to_box_4_2(poly_box: Union[list, np.ndarray]) -> List[List[float]]: xmin, ymin, xmax, ymax = tuple(poly_box) return [[xmin, ymin], [xmax, ymin], [xmax, ymax], [xmin, ymax]] From 6cda006c15db3f4f340e38abc18abb9541c9f108 Mon Sep 17 00:00:00 2001 From: Jokcer <519548295@qq.com> Date: Sat, 8 Mar 2025 15:03:47 +0800 Subject: [PATCH 2/8] feat: add paddle cls for table cls --- demo_table_cls.py | 11 ++-- lineless_table_rec/main.py | 1 + table_cls/main.py | 99 +++++++++++++++++++++++++++---- table_cls/utils/__init__.py | 0 table_cls/utils/download_model.py | 67 +++++++++++++++++++++ table_cls/utils/logger.py | 21 +++++++ table_cls/{ => utils}/utils.py | 0 7 files changed, 184 insertions(+), 15 deletions(-) create mode 100644 table_cls/utils/__init__.py create mode 100644 table_cls/utils/download_model.py create mode 100644 table_cls/utils/logger.py rename table_cls/{ => utils}/utils.py (100%) diff --git a/demo_table_cls.py b/demo_table_cls.py index 2300126..321ff9d 100644 --- a/demo_table_cls.py +++ b/demo_table_cls.py @@ -1,8 +1,9 @@ # -*- encoding: utf-8 -*- from table_cls import TableCls -table_cls = TableCls() -img_path = "tests/test_files/table_cls/lineless_table.png" -cls_str, elapse = table_cls(img_path) -print(cls_str) -print(elapse) +if __name__ == "__main__": + table_cls = TableCls(model_type="yolox") + img_path = "tests/test_files/table_cls/lineless_table_2.png" + cls_str, elapse = table_cls(img_path) + print(cls_str) + print(elapse) diff --git a/lineless_table_rec/main.py b/lineless_table_rec/main.py index c7514c9..ef4261c 100644 --- a/lineless_table_rec/main.py +++ b/lineless_table_rec/main.py @@ -84,6 +84,7 @@ def __call__( need_ocr = True if kwargs: rec_again = kwargs.get("rec_again", True) + need_ocr = kwargs.get("need_ocr", True) img = self.load_img(content) try: polygons, logi_points = self.table_structure(img) diff --git a/table_cls/main.py b/table_cls/main.py index 179c0cc..4fdb8cc 100644 --- a/table_cls/main.py +++ b/table_cls/main.py @@ -1,26 +1,42 @@ import time +from enum import Enum from pathlib import Path +from typing import Union, Dict import cv2 import numpy as np from PIL import Image -from .utils import InputType, LoadImage, OrtInferSession, resize_and_center_crop +from .utils.download_model import DownloadModel +from .utils.utils import InputType, LoadImage, OrtInferSession, resize_and_center_crop -cur_dir = Path(__file__).resolve().parent -q_cls_model_path = cur_dir / "models" / "table_cls.onnx" -yolo_cls_model_path = cur_dir / "models" / "yolo_cls.onnx" -yolo_cls_x_model_path = cur_dir / "models" / "yolo_cls_x.onnx" + +class ModelType(Enum): + YOLO_CLS_X = "yolox" + YOLO_CLS = "yolo" + PADDLE_CLS = "paddle" + Q_CLS = "q" + + +ROOT_URL = "https://www.modelscope.cn/models/RapidAI/RapidTable/resolve/master/" +KEY_TO_MODEL_URL = { + ModelType.YOLO_CLS_X.value: f"{ROOT_URL}/table_cls/yolo_cls_x.onnx", + ModelType.YOLO_CLS.value: 
f"{ROOT_URL}/table_cls/yolo_cls.onnx", + ModelType.PADDLE_CLS.value: f"{ROOT_URL}/table_cls/paddle_cls.onnx", + ModelType.Q_CLS.value: f"{ROOT_URL}/table_cls/q_cls.onnx", +} class TableCls: - def __init__(self, model_type="yolo", model_path=yolo_cls_model_path): - if model_type == "yolo": + def __init__(self, model_type=ModelType.YOLO_CLS.value, model_path=None): + model_path = self.get_model_path(model_type, model_path) + if model_type == ModelType.YOLO_CLS.value: + self.table_engine = YoloCls(model_path) + elif model_type == ModelType.YOLO_CLS_X.value: self.table_engine = YoloCls(model_path) - elif model_type == "yolox": - self.table_engine = YoloCls(yolo_cls_x_model_path) + elif model_type == ModelType.PADDLE_CLS.value: + self.table_engine = PaddleCls(model_path) else: - model_path = q_cls_model_path self.table_engine = QanythingCls(model_path) self.load_img = LoadImage() @@ -32,6 +48,69 @@ def __call__(self, content: InputType): table_elapse = time.perf_counter() - ss return predict_cla, table_elapse + @staticmethod + def get_model_path( + model_type: str, model_path: Union[str, Path, None] + ) -> Union[str, Dict[str, str]]: + if model_path is not None: + return model_path + + model_url = KEY_TO_MODEL_URL.get(model_type, None) + if isinstance(model_url, str): + model_path = DownloadModel.download(model_url) + return model_path + + if isinstance(model_url, dict): + model_paths = {} + for k, url in model_url.items(): + model_paths[k] = DownloadModel.download( + url, save_model_name=f"{model_type}_{Path(url).name}" + ) + return model_paths + + raise ValueError(f"Model URL: {type(model_url)} is not between str and dict.") + + +class PaddleCls: + def __init__(self, model_path): + self.table_cls = OrtInferSession(model_path) + self.inp_h = 224 + self.inp_w = 224 + self.resize_short = 256 + self.mean = np.array([0.485, 0.456, 0.406], dtype=np.float32) + self.std = np.array([0.229, 0.224, 0.225], dtype=np.float32) + self.cls = {0: "wired", 1: "wireless"} + + def preprocess(self, img): + # short resize + img_h, img_w = img.shape[:2] + percent = float(self.resize_short) / min(img_w, img_h) + w = int(round(img_w * percent)) + h = int(round(img_h * percent)) + img = cv2.resize(img, dsize=(w, h), interpolation=cv2.INTER_LANCZOS4) + # center crop + img_h, img_w = img.shape[:2] + w_start = (img_w - self.inp_w) // 2 + h_start = (img_h - self.inp_h) // 2 + w_end = w_start + self.inp_w + h_end = h_start + self.inp_h + img = img[h_start:h_end, w_start:w_end, :] + # normalize + img = np.array(img, dtype=np.float32) / 255.0 + img -= self.mean + img /= self.std + # HWC to CHW + img = img.transpose(2, 0, 1) + # Add batch dimension, only one image + img = np.expand_dims(img, axis=0) + return img + + def __call__(self, img): + pred_output = self.table_cls(img)[0] + pred_idxs = list(np.argmax(pred_output, axis=1)) + predict_cla = max(set(pred_idxs), key=pred_idxs.count) + return self.cls[predict_cla] + class QanythingCls: def __init__(self, model_path): diff --git a/table_cls/utils/__init__.py b/table_cls/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/table_cls/utils/download_model.py b/table_cls/utils/download_model.py new file mode 100644 index 0000000..adedb5d --- /dev/null +++ b/table_cls/utils/download_model.py @@ -0,0 +1,67 @@ +import io +from pathlib import Path +from typing import Optional, Union + +import requests +from tqdm import tqdm + +from .logger import get_logger + +logger = get_logger("DownloadModel") + +PROJECT_DIR = Path(__file__).resolve().parent.parent 
+DEFAULT_MODEL_DIR = PROJECT_DIR / "models" + + +class DownloadModel: + @classmethod + def download( + cls, + model_full_url: Union[str, Path], + save_dir: Union[str, Path, None] = None, + save_model_name: Optional[str] = None, + ) -> str: + if save_dir is None: + save_dir = DEFAULT_MODEL_DIR + + save_dir.mkdir(parents=True, exist_ok=True) + + if save_model_name is None: + save_model_name = Path(model_full_url).name + + save_file_path = save_dir / save_model_name + if save_file_path.exists(): + logger.debug("%s already exists", save_file_path) + return str(save_file_path) + + try: + logger.info("Download %s to %s", model_full_url, save_dir) + file = cls.download_as_bytes_with_progress(model_full_url, save_model_name) + cls.save_file(save_file_path, file) + except Exception as exc: + raise DownloadModelError from exc + return str(save_file_path) + + @staticmethod + def download_as_bytes_with_progress( + url: Union[str, Path], name: Optional[str] = None + ) -> bytes: + resp = requests.get(str(url), stream=True, allow_redirects=True, timeout=180) + total = int(resp.headers.get("content-length", 0)) + bio = io.BytesIO() + with tqdm( + desc=name, total=total, unit="b", unit_scale=True, unit_divisor=1024 + ) as pbar: + for chunk in resp.iter_content(chunk_size=65536): + pbar.update(len(chunk)) + bio.write(chunk) + return bio.getvalue() + + @staticmethod + def save_file(save_path: Union[str, Path], file: bytes): + with open(save_path, "wb") as f: + f.write(file) + + +class DownloadModelError(Exception): + pass diff --git a/table_cls/utils/logger.py b/table_cls/utils/logger.py new file mode 100644 index 0000000..2950987 --- /dev/null +++ b/table_cls/utils/logger.py @@ -0,0 +1,21 @@ +# -*- encoding: utf-8 -*- +# @Author: Jocker1212 +# @Contact: xinyijianggo@gmail.com +import logging +from functools import lru_cache + + +@lru_cache(maxsize=32) +def get_logger(name: str) -> logging.Logger: + logger = logging.getLogger(name) + logger.setLevel(logging.DEBUG) + + fmt = "%(asctime)s - %(name)s - %(levelname)s: %(message)s" + format_str = logging.Formatter(fmt) + + sh = logging.StreamHandler() + sh.setLevel(logging.DEBUG) + + logger.addHandler(sh) + sh.setFormatter(format_str) + return logger diff --git a/table_cls/utils.py b/table_cls/utils/utils.py similarity index 100% rename from table_cls/utils.py rename to table_cls/utils/utils.py From 7fd2549ebbaf85e59f58170c546c7ce19daaf087 Mon Sep 17 00:00:00 2001 From: Jokcer <519548295@qq.com> Date: Sat, 8 Mar 2025 15:04:06 +0800 Subject: [PATCH 3/8] test: fix test --- tests/test_lineless_table_rec.py | 46 +++++++++++++--------- tests/test_wired_table_rec.py | 67 ++++++++++++++++++++------------ 2 files changed, 71 insertions(+), 42 deletions(-) diff --git a/tests/test_lineless_table_rec.py b/tests/test_lineless_table_rec.py index 0cd36f1..3eb42d7 100644 --- a/tests/test_lineless_table_rec.py +++ b/tests/test_lineless_table_rec.py @@ -5,6 +5,8 @@ from pathlib import Path import pytest +from lineless_table_rec.main import RapidTableInput, ModelType + cur_dir = Path(__file__).resolve().parent root_dir = cur_dir.parent @@ -14,8 +16,8 @@ from lineless_table_rec import LinelessTableRecognition test_file_dir = cur_dir / "test_files" - -table_recog = LinelessTableRecognition() +input_args = RapidTableInput(model_type=ModelType.LORE.value) +table_recog = LinelessTableRecognition(input_args) @pytest.mark.parametrize( @@ -27,12 +29,15 @@ ) def test_input_normal(img_path, table_str_len, td_nums): img_path = test_file_dir / img_path - img = cv2.imread(str(img_path)) - 
table_str, *_ = table_recog(img) + table_results = table_recog(str(img_path)) + table_html_str, table_cell_bboxes = ( + table_results.pred_html, + table_results.cell_bboxes, + ) - assert len(table_str) >= table_str_len - assert table_str.count("td") == td_nums + assert len(table_html_str) >= table_str_len + assert table_html_str.count("td") == td_nums @pytest.mark.parametrize( @@ -254,12 +259,15 @@ def test_plot_html_table(logi_points, cell_box_map, expected_html): ) def test_no_rec_again(img_path, table_str_len, td_nums): img_path = test_file_dir / img_path - img = cv2.imread(str(img_path)) - table_str, *_ = table_recog(img, rec_again=False) + table_results = table_recog(str(img_path), rec_again=False) + table_html_str, table_cell_bboxes = ( + table_results.pred_html, + table_results.cell_bboxes, + ) - assert len(table_str) >= table_str_len - assert table_str.count("td") == td_nums + assert len(table_html_str) >= table_str_len + assert table_html_str.count("td") == td_nums @pytest.mark.parametrize( @@ -271,12 +279,14 @@ def test_no_rec_again(img_path, table_str_len, td_nums): ) def test_no_ocr(img_path, html_output, points_len): img_path = test_file_dir / img_path - - html, elasp, polygons, logic_points, ocr_res = table_recog( - str(img_path), need_ocr=False + table_results = table_recog(str(img_path), need_ocr=False) + table_html_str, table_cell_bboxes, table_logic_points = ( + table_results.pred_html, + table_results.cell_bboxes, + table_results.logic_points, ) - assert len(ocr_res) == 0 - assert len(polygons) > points_len - assert len(logic_points) > points_len - assert len(polygons) == len(logic_points) - assert html == html_output + + assert len(table_cell_bboxes) > points_len + assert len(table_logic_points) > points_len + assert len(table_cell_bboxes) == len(table_logic_points) + assert table_html_str == html_output diff --git a/tests/test_wired_table_rec.py b/tests/test_wired_table_rec.py index 6604a57..55d8818 100644 --- a/tests/test_wired_table_rec.py +++ b/tests/test_wired_table_rec.py @@ -8,7 +8,8 @@ from bs4 import BeautifulSoup from rapidocr_onnxruntime import RapidOCR -from wired_table_rec.utils import rescale_size +from wired_table_rec.main import RapidTableInput, ModelType +from wired_table_rec.utils.utils import rescale_size from wired_table_rec.utils.utils_table_recover import ( plot_html_table, is_single_axis_contained, @@ -25,8 +26,8 @@ from wired_table_rec import WiredTableRecognition test_file_dir = cur_dir / "test_files" / "wired" - -table_recog = WiredTableRecognition() +input_args = RapidTableInput(model_type=ModelType.UNET.value) +table_recog = WiredTableRecognition(input_args) ocr_engine = RapidOCR() @@ -40,9 +41,13 @@ def get_td_nums(html: str) -> int: def test_squeeze_bug(): img_path = test_file_dir / "squeeze_error.jpeg" - ocr_result, _ = ocr_engine(img_path) - table_str, *_ = table_recog(str(img_path), ocr_result) - td_nums = get_td_nums(table_str) + ocr_result, _ = ocr_engine(str(img_path)) + table_results = table_recog(str(img_path)) + table_html_str, table_cell_bboxes = ( + table_results.pred_html, + table_results.cell_bboxes, + ) + td_nums = get_td_nums(table_html_str) assert td_nums >= 160 @@ -58,9 +63,13 @@ def test_squeeze_bug(): def test_input_normal(img_path, gt_td_nums, gt2): img_path = test_file_dir / img_path - ocr_result, _ = ocr_engine(img_path) - table_str, *_ = table_recog(str(img_path), ocr_result) - td_nums = get_td_nums(table_str) + ocr_result, _ = ocr_engine(str(img_path)) + table_results = table_recog(str(img_path)) + table_html_str, 
table_cell_bboxes = ( + table_results.pred_html, + table_results.cell_bboxes, + ) + td_nums = get_td_nums(table_html_str) assert td_nums >= gt_td_nums @@ -74,9 +83,13 @@ def test_input_normal(img_path, gt_td_nums, gt2): def test_enhance_box_line(img_path, gt_td_nums): img_path = test_file_dir / img_path - ocr_result, _ = ocr_engine(img_path) - table_str, *_ = table_recog(str(img_path), ocr_result, enhance_box_line=False) - td_nums = get_td_nums(table_str) + ocr_result, _ = ocr_engine(str(img_path)) + table_results = table_recog(str(img_path), enhance_box_line=False) + table_html_str, table_cell_bboxes = ( + table_results.pred_html, + table_results.cell_bboxes, + ) + td_nums = get_td_nums(table_html_str) assert td_nums <= gt_td_nums @@ -291,10 +304,13 @@ def test_plot_html_table(logi_points, cell_box_map, expected_html): def test_no_rec_again(img_path, gt_td_nums, gt2): img_path = test_file_dir / img_path - ocr_result, _ = ocr_engine(img_path) - table_str, *_ = table_recog(str(img_path), ocr_result, rec_again=False) - td_nums = get_td_nums(table_str) - + ocr_result, _ = ocr_engine(str(img_path)) + table_results = table_recog(str(img_path), rec_again=False) + table_html_str, table_cell_bboxes = ( + table_results.pred_html, + table_results.cell_bboxes, + ) + td_nums = get_td_nums(table_html_str) assert td_nums >= gt_td_nums @@ -308,12 +324,15 @@ def test_no_rec_again(img_path, gt_td_nums, gt2): def test_no_ocr(img_path, html_output, points_len): img_path = test_file_dir / img_path - ocr_result, _ = ocr_engine(img_path) - html, elasp, polygons, logic_points, ocr_res = table_recog( - str(img_path), ocr_result, need_ocr=False + ocr_result, _ = ocr_engine(str(img_path)) + table_results = table_recog(str(img_path), need_ocr=False) + table_html_str, table_cell_bboxes, table_logic_points = ( + table_results.pred_html, + table_results.cell_bboxes, + table_results.logic_points, ) - assert len(ocr_res) == 0 - assert len(polygons) > points_len - assert len(logic_points) > points_len - assert len(polygons) == len(logic_points) - assert html == html_output + + assert len(table_cell_bboxes) > points_len + assert len(table_logic_points) > points_len + assert len(table_cell_bboxes) == len(table_logic_points) + assert table_html_str == html_output From b844584c490089df189fe2e58014fe82cad77f31 Mon Sep 17 00:00:00 2001 From: Jokcer <519548295@qq.com> Date: Sat, 8 Mar 2025 16:40:46 +0800 Subject: [PATCH 4/8] fix: fix wired unet model rec --- demo_all.py | 41 ++++++++++++++++++++++++++++++++ demo_lineless.py | 4 ++-- demo_wired.py | 4 ++-- lineless_table_rec/main.py | 21 +++++++++------- tests/test_lineless_table_rec.py | 4 ++-- tests/test_wired_table_rec.py | 4 ++-- wired_table_rec/main.py | 28 +++++++++++----------- 7 files changed, 76 insertions(+), 30 deletions(-) create mode 100644 demo_all.py diff --git a/demo_all.py b/demo_all.py new file mode 100644 index 0000000..8ee15d5 --- /dev/null +++ b/demo_all.py @@ -0,0 +1,41 @@ +from table_cls import TableCls +from wired_table_rec.main import WiredTableInput, WiredTableRecognition +from lineless_table_rec.main import LinelessTableInput, LinelessTableRecognition + +if __name__ == "__main__": + # Init + wired_input = WiredTableInput() + lineless_input = LinelessTableInput() + wired_engine = WiredTableRecognition(wired_input) + lineless_engine = LinelessTableRecognition(lineless_input) + # 默认小yolo模型(0.1s),可切换为精度更高yolox(0.25s),更快的qanything(0.07s)模型或paddle模型(0.03s) + table_cls = TableCls() + img_path = f"tests/test_files/table.jpg" + + cls, elasp = table_cls(img_path) 
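+    # table_cls returns a ("wired" | "wireless") label plus the elapsed seconds;
+    # the label selects the matching structure-recognition engine below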
+ if cls == "wired": + table_engine = wired_engine + else: + table_engine = lineless_engine + + table_results = table_engine(img_path, enhance_box_line=False) + # 使用RapidOCR输入 + # ocr_engine = RapidOCR() + # ocr_result, _ = ocr_engine(img_path) + # table_results = table_engine(img_path, ocr_result=ocr_result) + + # Visualize table rec result + # save_dir = Path("outputs") + # save_dir.mkdir(parents=True, exist_ok=True) + # + # save_html_path = f"outputs/{Path(img_path).stem}.html" + # save_drawed_path = f"outputs/{Path(img_path).stem}_table_vis{Path(img_path).suffix}" + # save_logic_path = ( + # f"outputs/{Path(img_path).stem}_table_vis_logic{Path(img_path).suffix}" + # ) + + # + # vis_table = VisTable() + # vis_imged = vis_table( + # img_path, table_results, save_html_path, save_drawed_path, save_logic_path + # ) diff --git a/demo_lineless.py b/demo_lineless.py index cb8ddef..c1519e7 100644 --- a/demo_lineless.py +++ b/demo_lineless.py @@ -6,12 +6,12 @@ from rapidocr_onnxruntime import RapidOCR from lineless_table_rec import LinelessTableRecognition -from lineless_table_rec.main import RapidTableInput +from lineless_table_rec.main import LinelessTableInput from lineless_table_rec.utils.utils import VisTable output_dir = Path("outputs") output_dir.mkdir(parents=True, exist_ok=True) -input_args = RapidTableInput() +input_args = LinelessTableInput() table_engine = LinelessTableRecognition(input_args) ocr_engine = RapidOCR() viser = VisTable() diff --git a/demo_wired.py b/demo_wired.py index 385ce39..65d2eb7 100644 --- a/demo_wired.py +++ b/demo_wired.py @@ -6,12 +6,12 @@ from rapidocr_onnxruntime import RapidOCR from wired_table_rec import WiredTableRecognition -from wired_table_rec.main import RapidTableInput, ModelType +from wired_table_rec.main import WiredTableInput from wired_table_rec.utils.utils import VisTable output_dir = Path("outputs") output_dir.mkdir(parents=True, exist_ok=True) -input_args = RapidTableInput(model_type=ModelType.CYCLE_CENTER_NET.value) +input_args = WiredTableInput() table_engine = WiredTableRecognition(input_args) ocr_engine = RapidOCR() viser = VisTable() diff --git a/lineless_table_rec/main.py b/lineless_table_rec/main.py index ef4261c..9b4a705 100644 --- a/lineless_table_rec/main.py +++ b/lineless_table_rec/main.py @@ -24,6 +24,7 @@ match_ocr_cell, plot_html_table, sorted_ocr_boxes, + box_4_1_poly_to_box_4_2, ) @@ -41,7 +42,7 @@ class ModelType(Enum): @dataclass -class RapidTableInput: +class LinelessTableInput: model_type: Optional[str] = ModelType.LORE.value model_path: Union[str, Path, None, Dict[str, str]] = None use_cuda: bool = False @@ -49,7 +50,7 @@ class RapidTableInput: @dataclass -class RapidTableOutput: +class LinelessTableOutput: pred_html: Optional[str] = None cell_bboxes: Optional[np.ndarray] = None logic_points: Optional[np.ndarray] = None @@ -57,7 +58,7 @@ class RapidTableOutput: class LinelessTableRecognition: - def __init__(self, config: RapidTableInput): + def __init__(self, config: LinelessTableInput): self.model_type = config.model_type if self.model_type not in KEY_TO_MODEL_URL: model_list = ",".join(KEY_TO_MODEL_URL) @@ -78,7 +79,7 @@ def __call__( content: InputType, ocr_result: Optional[List[Union[List[List[float]], str, str]]] = None, **kwargs, - ) -> RapidTableOutput: + ) -> LinelessTableOutput: s = time.perf_counter() rec_again = True need_ocr = True @@ -92,7 +93,7 @@ def __call__( sorted_polygons, idx_list = sorted_ocr_boxes( [box_4_2_poly_to_box_4_1(box) for box in polygons] ) - return RapidTableOutput( + return 
LinelessTableOutput( "", sorted_polygons, logi_points[idx_list], @@ -121,6 +122,10 @@ def __call__( # 将同一个识别框中的ocr结果排序并同行合并 t_rec_ocr_list = self.sort_and_gather_ocr_res(t_rec_ocr_list) # 渲染为html + polygons = [ + box_4_1_poly_to_box_4_2(t_box_ocr["t_box"]) + for t_box_ocr in t_rec_ocr_list + ] logi_points = [t_box_ocr["t_logic_box"] for t_box_ocr in t_rec_ocr_list] cell_box_det_map = { i: [ocr_box_and_text[1] for ocr_box_and_text in t_box_ocr["t_ocr_res"]] @@ -132,13 +137,13 @@ def __call__( _, idx_list = sorted_ocr_boxes( [t_box_ocr["t_box"] for t_box_ocr in t_rec_ocr_list] ) - polygons = polygons.reshape(-1, 8) + polygons = np.array(polygons).reshape(-1, 8) logi_points = np.array(logi_points) elapse = time.perf_counter() - s except Exception: logging.warning(traceback.format_exc()) - return RapidTableOutput("", None, None, 0.0) - return RapidTableOutput(pred_html, polygons, logi_points, elapse) + return LinelessTableOutput("", None, None, 0.0) + return LinelessTableOutput(pred_html, polygons, logi_points, elapse) def transform_res( self, diff --git a/tests/test_lineless_table_rec.py b/tests/test_lineless_table_rec.py index 3eb42d7..e796f94 100644 --- a/tests/test_lineless_table_rec.py +++ b/tests/test_lineless_table_rec.py @@ -5,7 +5,7 @@ from pathlib import Path import pytest -from lineless_table_rec.main import RapidTableInput, ModelType +from lineless_table_rec.main import LinelessTableInput, ModelType cur_dir = Path(__file__).resolve().parent root_dir = cur_dir.parent @@ -16,7 +16,7 @@ from lineless_table_rec import LinelessTableRecognition test_file_dir = cur_dir / "test_files" -input_args = RapidTableInput(model_type=ModelType.LORE.value) +input_args = LinelessTableInput(model_type=ModelType.LORE.value) table_recog = LinelessTableRecognition(input_args) diff --git a/tests/test_wired_table_rec.py b/tests/test_wired_table_rec.py index 55d8818..44c3615 100644 --- a/tests/test_wired_table_rec.py +++ b/tests/test_wired_table_rec.py @@ -8,7 +8,7 @@ from bs4 import BeautifulSoup from rapidocr_onnxruntime import RapidOCR -from wired_table_rec.main import RapidTableInput, ModelType +from wired_table_rec.main import WiredTableInput, ModelType from wired_table_rec.utils.utils import rescale_size from wired_table_rec.utils.utils_table_recover import ( plot_html_table, @@ -26,7 +26,7 @@ from wired_table_rec import WiredTableRecognition test_file_dir = cur_dir / "test_files" / "wired" -input_args = RapidTableInput(model_type=ModelType.UNET.value) +input_args = WiredTableInput(model_type=ModelType.UNET.value) table_recog = WiredTableRecognition(input_args) ocr_engine = RapidOCR() diff --git a/wired_table_rec/main.py b/wired_table_rec/main.py index b470525..2afa8c3 100644 --- a/wired_table_rec/main.py +++ b/wired_table_rec/main.py @@ -41,7 +41,7 @@ class ModelType(Enum): @dataclass -class RapidTableInput: +class WiredTableInput: model_type: Optional[str] = ModelType.UNET.value model_path: Union[str, Path, None, Dict[str, str]] = None use_cuda: bool = False @@ -49,7 +49,7 @@ class RapidTableInput: @dataclass -class RapidTableOutput: +class WiredTableOutput: pred_html: Optional[str] = None cell_bboxes: Optional[np.ndarray] = None logic_points: Optional[np.ndarray] = None @@ -57,7 +57,7 @@ class RapidTableOutput: class WiredTableRecognition: - def __init__(self, config: RapidTableInput): + def __init__(self, config: WiredTableInput): self.model_type = config.model_type if self.model_type not in KEY_TO_MODEL_URL: model_list = ",".join(KEY_TO_MODEL_URL) @@ -85,7 +85,7 @@ def __call__( img: InputType, 
ocr_result: Optional[List[Union[List[List[float]], str, str]]] = None, **kwargs, - ) -> RapidTableOutput: + ) -> WiredTableOutput: s = time.perf_counter() rec_again = True need_ocr = True @@ -100,7 +100,7 @@ def __call__( polygons, rotated_polygons = self.table_structure(img, **kwargs) if polygons is None: logging.warning("polygons is None.") - return RapidTableOutput("", None, None, 0.0) + return WiredTableOutput("", None, None, 0.0) try: table_res, logi_points = self.table_recover( @@ -115,7 +115,7 @@ def __call__( sorted_polygons, idx_list = sorted_ocr_boxes( [box_4_2_poly_to_box_4_1(box) for box in polygons] ) - return RapidTableOutput( + return WiredTableOutput( "", sorted_polygons, logi_points[idx_list], @@ -137,14 +137,14 @@ def __call__( for i, t_box_ocr in enumerate(t_rec_ocr_list) } pred_html = plot_html_table(logi_points, cell_box_det_map) - polygons = polygons.reshape(-1, 8) + polygons = np.array(polygons).reshape(-1, 8) logi_points = np.array(logi_points) elapse = time.perf_counter() - s except Exception: logging.warning(traceback.format_exc()) - return RapidTableOutput("", None, None, 0.0) - return RapidTableOutput(pred_html, polygons, logi_points, elapse) + return WiredTableOutput("", None, None, 0.0) + return WiredTableOutput(pred_html, polygons, logi_points, elapse) def transform_res( self, @@ -276,12 +276,12 @@ def main(): raise ModuleNotFoundError( "Please install the rapidocr_onnxruntime by pip install rapidocr_onnxruntime." ) from exc - - table_rec = WiredTableRecognition() + input_args = WiredTableInput() + table_rec = WiredTableRecognition(input_args) ocr_result, _ = ocr_engine(args.img_path) - table_str, elapse = table_rec(args.img_path, ocr_result) - print(table_str) - print(f"cost: {elapse:.5f}") + table_results = table_rec(args.img_path, ocr_result) + print(table_results.pred_html) + print(f"cost: {table_results.elapse:.5f}") if __name__ == "__main__": From 49ac5134ec5281d4fa393126b128e55b37e20eeb Mon Sep 17 00:00:00 2001 From: Jokcer <519548295@qq.com> Date: Sat, 8 Mar 2025 16:43:18 +0800 Subject: [PATCH 5/8] chore: add readme & change workflow --- .github/workflows/lineless_table_rec.yml | 8 -- .github/workflows/table_cls.yml | 8 -- .github/workflows/wired_table_rec.yml | 10 -- README.md | 121 +++++++++++++++-------- README_en.md | 111 ++++++++++++++------- setup_lineless.py | 3 +- setup_table_cls.py | 3 +- setup_wired.py | 3 +- 8 files changed, 158 insertions(+), 109 deletions(-) diff --git a/.github/workflows/lineless_table_rec.yml b/.github/workflows/lineless_table_rec.yml index 9cdba75..8a93fe3 100644 --- a/.github/workflows/lineless_table_rec.yml +++ b/.github/workflows/lineless_table_rec.yml @@ -30,10 +30,6 @@ jobs: pip install -r requirements.txt pip install pytest - wget https://github.com/RapidAI/TableStructureRec/releases/download/v0.0.0/lineless_table_rec_models.zip - unzip lineless_table_rec_models.zip - mv lineless_table_rec_models/*.onnx lineless_table_rec/models/ - pytest tests/test_lineless_table_rec.py GenerateWHL_PushPyPi: @@ -55,10 +51,6 @@ jobs: python -m pip install --upgrade pip pip install wheel get_pypi_latest_version - wget https://github.com/RapidAI/TableStructureRec/releases/download/v0.0.0/lineless_table_rec_models.zip - unzip lineless_table_rec_models.zip - mv lineless_table_rec_models/*.onnx lineless_table_rec/models/ - python setup_lineless.py bdist_wheel "${{ github.ref_name }}" # - name: Publish distribution 📦 to Test PyPI diff --git a/.github/workflows/table_cls.yml b/.github/workflows/table_cls.yml index 797d8cf..17ad03c 
100644 --- a/.github/workflows/table_cls.yml +++ b/.github/workflows/table_cls.yml @@ -29,10 +29,6 @@ jobs: pip install -r requirements.txt pip install pytest beautifulsoup4 - wget https://github.com/RapidAI/TableStructureRec/releases/download/v0.0.0/table_cls_models.zip - unzip table_cls_models.zip - mv table_cls_models/*.onnx table_cls/models/ - pytest tests/test_table_cls.py GenerateWHL_PushPyPi: @@ -54,10 +50,6 @@ jobs: python -m pip install --upgrade pip pip install wheel get_pypi_latest_version - wget https://github.com/RapidAI/TableStructureRec/releases/download/v0.0.0/table_cls_models.zip - unzip table_cls_models.zip - mv table_cls_models/*.onnx table_cls/models/ - python setup_table_cls.py bdist_wheel "${{ github.ref_name }}" - name: Publish distribution 📦 to PyPI diff --git a/.github/workflows/wired_table_rec.yml b/.github/workflows/wired_table_rec.yml index fc65e1b..e14f7aa 100644 --- a/.github/workflows/wired_table_rec.yml +++ b/.github/workflows/wired_table_rec.yml @@ -28,11 +28,6 @@ jobs: run: | pip install -r requirements.txt pip install pytest beautifulsoup4 - - wget https://github.com/RapidAI/TableStructureRec/releases/download/v0.0.0/wired_table_rec_models.zip - unzip wired_table_rec_models.zip - mv wired_table_rec_models/*.onnx wired_table_rec/models/ - pytest tests/test_wired_table_rec.py GenerateWHL_PushPyPi: @@ -53,11 +48,6 @@ jobs: pip install -r requirements.txt python -m pip install --upgrade pip pip install wheel get_pypi_latest_version - - wget https://github.com/RapidAI/TableStructureRec/releases/download/v0.0.0/wired_table_rec_models.zip - unzip wired_table_rec_models.zip - mv wired_table_rec_models/*.onnx wired_table_rec/models/ - python setup_wired.py bdist_wheel "${{ github.ref_name }}" - name: Publish distribution 📦 to PyPI diff --git a/README.md b/README.md index 2018db3..a773ec7 100644 --- a/README.md +++ b/README.md @@ -15,12 +15,14 @@ ### 最近更新 -- **2024.11.22** - - 支持单字符匹配方案,需要RapidOCR>=1.4.0 - **2024.12.25** - 补充文档扭曲矫正/去模糊/去阴影/二值化方案,可作为前置处理 [RapidUnDistort](https://github.com/Joker1212/RapidUnWrap) - **2025.1.9** - - RapidTable支持了 unitable 模型,精度更高支持torch推理,补充测评数据 + - RapidTable支持了 unitable 模型,精度更高支持torch推理,补充测评数据 +- **2025.3.9** + - 输入输出格式对齐RapidTable + - 支持模型自动下载 + - 增加来自paddle的新表格分类模型 ### 简介 💖该仓库是用来对文档中表格做结构化识别的推理库,包括来自阿里读光有线和无线表格识别模型,llaipython(微信)贡献的有线表格模型,网易Qanything内置表格分类模型等。\ @@ -81,55 +83,63 @@ pip install wired_table_rec lineless_table_rec table_cls ``` ### 快速使用 - +> ⚠️注意:在`wired_table_rec/table_cls`>=1.2.0` `lineless_table_rec` > 0.1.0 后,采用同RapidTable完全一致格式的输入输出 ``` python {linenos=table} -import os +from pathlib import Path -from lineless_table_rec import LinelessTableRecognition -from lineless_table_rec.utils_table_recover import format_html, plot_rec_box_with_logic_info, plot_rec_box +from wired_table_rec.utils.utils import VisTable from table_cls import TableCls -from wired_table_rec import WiredTableRecognition -from rapidocr_onnxruntime import RapidOCR - -lineless_engine = LinelessTableRecognition() -wired_engine = WiredTableRecognition() -# 默认小yolo模型(0.1s),可切换为精度更高yolox(0.25s),更快的qanything(0.07s)模型 -table_cls = TableCls() # TableCls(model_type="yolox"),TableCls(model_type="q") -img_path = f'images/img14.jpg' +from wired_table_rec.main import WiredTableInput, WiredTableRecognition +from lineless_table_rec.main import LinelessTableInput, LinelessTableRecognition +from rapidocr_onnxruntime import RapidOCR, VisRes + +# 初始化引擎 +wired_input = WiredTableInput() +lineless_input = LinelessTableInput() +wired_engine = 
WiredTableRecognition(wired_input) +lineless_engine = LinelessTableRecognition(lineless_input) +# 默认小yolo模型(0.1s),可切换为精度更高yolox(0.25s),更快的qanything(0.07s)模型或paddle模型(0.03s) +table_cls = TableCls() +img_path = f'tests/test_files/table.jpg' cls,elasp = table_cls(img_path) if cls == 'wired': table_engine = wired_engine else: table_engine = lineless_engine - -html, elasp, polygons, logic_points, ocr_res = table_engine(img_path) -print(f"elasp: {elasp}") - -# 使用其他ocr模型 -#ocr_engine =RapidOCR(det_model_path="xxx/det_server_infer.onnx",rec_model_path="xxx/rec_server_infer.onnx") -#ocr_res, _ = ocr_engine(img_path) -#html, elasp, polygons, logic_points, ocr_res = table_engine(img_path, ocr_result=ocr_res) -# output_dir = f'outputs' -# complete_html = format_html(html) -# os.makedirs(os.path.dirname(f"{output_dir}/table.html"), exist_ok=True) -# with open(f"{output_dir}/table.html", "w", encoding="utf-8") as file: -# file.write(complete_html) -# # 可视化表格识别框 + 逻辑行列信息 -# plot_rec_box_with_logic_info( -# img_path, f"{output_dir}/table_rec_box.jpg", logic_points, polygons + +table_results = table_engine(img_path, enhance_box_line=False) +# 使用RapidOCR输入 +# ocr_engine = RapidOCR() +# ocr_result, _ = ocr_engine(img_path) +# table_results = table_engine(img_path, ocr_result=ocr_result) + +# 可视化并存储结果,包含识别框+行列坐标 +# save_dir = Path("outputs") +# save_dir.mkdir(parents=True, exist_ok=True) +# +# save_html_path = f"outputs/{Path(img_path).stem}.html" +# save_drawed_path = f"outputs/{Path(img_path).stem}_table_vis{Path(img_path).suffix}" +# save_logic_path = ( +# f"outputs/{Path(img_path).stem}_table_vis_logic{Path(img_path).suffix}" # ) -# # 可视化 ocr 识别框 -# plot_rec_box(img_path, f"{output_dir}/ocr_box.jpg", ocr_res) +# +# vis_table = VisTable() +# vis_imged = vis_table( +# img_path, table_results, save_html_path, save_drawed_path, save_logic_path +# ) + ``` #### 单字ocr匹配 + ```python # 将单字box转换为行识别同样的结构) from rapidocr_onnxruntime import RapidOCR -from wired_table_rec.utils_table_recover import trans_char_ocr_res +from wired_table_rec.utils.utils_table_recover import trans_char_ocr_res + img_path = "tests/test_files/wired/table4.jpg" -ocr_engine =RapidOCR() +ocr_engine = RapidOCR() ocr_res, _ = ocr_engine(img_path, return_word_box=True) ocr_res = trans_char_ocr_res(ocr_res) ``` @@ -177,11 +187,42 @@ for i, res in enumerate(result): ### 核心参数 ```python -wired_table_rec = WiredTableRecognition() -html, elasp, polygons, logic_points, ocr_res = wired_table_rec( +# 输入(WiredTableInput/LinelessTableInput) +@dataclass +class WiredTableInput: + model_type: Optional[str] = "unet" #unet/cycle_center_net + model_path: Union[str, Path, None, Dict[str, str]] = None + use_cuda: bool = False + device: str = "cpu" + +@dataclass +class LinelessTableInput: + model_type: Optional[str] = "lore" #lore + model_path: Union[str, Path, None, Dict[str, str]] = None + use_cuda: bool = False + device: str = "cpu" + +# 输出(WiredTableOutput/LinelessTableOutput) +@dataclass +class WiredTableOutput: + pred_html: Optional[str] = None + cell_bboxes: Optional[np.ndarray] = None + logic_points: Optional[np.ndarray] = None + elapse: Optional[float] = None + +@dataclass +class LinelessTableOutput: + pred_html: Optional[str] = None + cell_bboxes: Optional[np.ndarray] = None + logic_points: Optional[np.ndarray] = None + elapse: Optional[float] = None +``` + +```python +wired_table_rec = WiredTableRecognition(WiredTableInput()) +table_results = wired_table_rec( img, # 图片 Union[str, np.ndarray, bytes, Path, PIL.Image.Image] ocr_result, # 
输入rapidOCR识别结果,不传默认使用内部rapidocr模型 - version="v2", #默认使用v2线框模型,切换阿里读光模型可改为v1 enhance_box_line=True, # 识别框切割增强(关闭避免多余切割,开启减少漏切割),默认为True col_threshold=15, # 识别框左边界x坐标差值小于col_threshold的默认同列 row_threshold=10, # 识别框上边界y坐标差值小于row_threshold的默认同行 @@ -189,8 +230,8 @@ html, elasp, polygons, logic_points, ocr_res = wired_table_rec( need_ocr=True, # 是否进行OCR识别, 默认为True rec_again=True,# 是否针对未识别到文字的表格框,进行单独截取再识别,默认为True ) -lineless_table_rec = LinelessTableRecognition() -html, elasp, polygons, logic_points, ocr_res = lineless_table_rec( +lineless_table_rec = LinelessTableRecognition(LinelessTableInput()) +table_results = lineless_table_rec( img, # 图片 Union[str, np.ndarray, bytes, Path, PIL.Image.Image] ocr_result, # 输入rapidOCR识别结果,不传默认使用内部rapidocr模型 need_ocr=True, # 是否进行OCR识别, 默认为True diff --git a/README_en.md b/README_en.md index 6613f4e..a0297e4 100644 --- a/README_en.md +++ b/README_en.md @@ -13,12 +13,14 @@ ### Recent Updates -- **2024.11.16** - - Added document distortion correction solution, which can be used as a pre-processing step [RapidUnWrap](https://github.com/Joker1212/RapidUnWrap) -- **2024.11.22** - - Support Char Rec, RapidOCR>=1.4.0 - **2024.12.25** - Add document preprocessing solutions for distortion correction, deblurring, shadow removal, and binarization. [RapidUnDistort](https://github.com/Joker1212/RapidUnWrap) +- **2025.1.9** + - RapidTable now supports the Unitable model, Evaluation data has been added. +- **2025.3.9** + - Align input and output formats with RapidTable + - support automatic model downloading + - introduce a new table classification model from [PaddleOCR](https://github.com/PaddlePaddle/PaddleX/blob/release/3.0-rc/docs/module_usage/tutorials/ocr_modules/table_classification.en.md). ### Introduction 💖 This repository serves as an inference library for structured recognition of tables within documents, including models for wired and wireless table recognition from Alibaba DulaLight, a wired table model from llaipython (WeChat), and a built-in table classification model from NetEase Qanything. 
@@ -79,55 +81,62 @@ pip install wired_table_rec lineless_table_rec table_cls ``` ### Quick start +> ⚠️:`wired_table_rec/table_cls`>=1.2.0` `lineless_table_rec` > 0.1.0 ,the input and output format are same with `RapidTable` ``` python {linenos=table} -import os +from pathlib import Path -from lineless_table_rec import LinelessTableRecognition -from lineless_table_rec.utils_table_recover import format_html, plot_rec_box_with_logic_info, plot_rec_box +from wired_table_rec.utils.utils import VisTable from table_cls import TableCls -from wired_table_rec import WiredTableRecognition -from rapidocr_onnxruntime import RapidOCR - -lineless_engine = LinelessTableRecognition() -wired_engine = WiredTableRecognition() -# Default small YOLO model (0.1s), can switch to higher precision YOLOX (0.25s), or faster QAnything (0.07s) model -table_cls = TableCls() # TableCls(model_type="yolox"),TableCls(model_type="q") -img_path = f'images/img14.jpg' +from wired_table_rec.main import WiredTableInput, WiredTableRecognition +from lineless_table_rec.main import LinelessTableInput, LinelessTableRecognition +from rapidocr_onnxruntime import RapidOCR, VisRes + +# init engine +wired_input = WiredTableInput() +lineless_input = LinelessTableInput() +wired_engine = WiredTableRecognition(wired_input) +lineless_engine = LinelessTableRecognition(lineless_input) +#The default model is a small YOLO model (0.1s inference time), which can be switched to higher-precision YOLOX (0.25s), faster QAnything (0.07s), or PaddlePaddle models (0.03s). +table_cls = TableCls() +img_path = f'tests/test_files/table.jpg' cls,elasp = table_cls(img_path) if cls == 'wired': table_engine = wired_engine else: table_engine = lineless_engine - -html, elasp, polygons, logic_points, ocr_res = table_engine(img_path) -print(f"elasp: {elasp}") - -# Use other OCR models -#ocr_engine =RapidOCR(det_model_path="xxx/det_server_infer.onnx",rec_model_path="xxx/rec_server_infer.onnx") -#ocr_res, _ = ocr_engine(img_path) -#html, elasp, polygons, logic_points, ocr_res = table_engine(img_path, ocr_result=ocr_res) - -# output_dir = f'outputs' -# complete_html = format_html(html) -# os.makedirs(os.path.dirname(f"{output_dir}/table.html"), exist_ok=True) -# with open(f"{output_dir}/table.html", "w", encoding="utf-8") as file: -# file.write(complete_html) -# Visualize table recognition boxes + logical row and column information -# plot_rec_box_with_logic_info( -# img_path, f"{output_dir}/table_rec_box.jpg", logic_points, polygons + +table_results = table_engine(img_path, enhance_box_line=False) +# use rapidOCR for as input +# ocr_engine = RapidOCR() +# ocr_result, _ = ocr_engine(img_path) +# table_results = table_engine(img_path, ocr_result=ocr_result) + +# Visualize and store the results, including detection bounding boxes and row/column coordinates. 
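+# (VisTable writes three artifacts: the bordered HTML, a cell-box overlay image,
+# and a row/column logic-point overlay image, saved to the paths set up below.)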
+# save_dir = Path("outputs") +# save_dir.mkdir(parents=True, exist_ok=True) +# +# save_html_path = f"outputs/{Path(img_path).stem}.html" +# save_drawed_path = f"outputs/{Path(img_path).stem}_table_vis{Path(img_path).suffix}" +# save_logic_path = ( +# f"outputs/{Path(img_path).stem}_table_vis_logic{Path(img_path).suffix}" +# ) +# +# vis_table = VisTable() +# vis_imged = vis_table( +# img_path, table_results, save_html_path, save_drawed_path, save_logic_path # ) -# Visualize OCR recognition boxes -# plot_rec_box(img_path, f"{output_dir}/ocr_box.jpg", ocr_res) ``` #### Single Character OCR Matching + ```python # Convert single character boxes to the same structure as line recognition from rapidocr_onnxruntime import RapidOCR -from wired_table_rec.utils_table_recover import trans_char_ocr_res +from wired_table_rec.utils.utils_table_recover import trans_char_ocr_res + img_path = "tests/test_files/wired/table4.jpg" -ocr_engine =RapidOCR() +ocr_engine = RapidOCR() ocr_res, _ = ocr_engine(img_path, return_word_box=True) ocr_res = trans_char_ocr_res(ocr_res) ``` @@ -174,11 +183,39 @@ for i, res in enumerate(result): ### Core Parameters ```python +@dataclass +class WiredTableInput: + model_type: Optional[str] = "unet" #unet/cycle_center_net + model_path: Union[str, Path, None, Dict[str, str]] = None + use_cuda: bool = False + device: str = "cpu" + +@dataclass +class LinelessTableInput: + model_type: Optional[str] = "lore" #lore + model_path: Union[str, Path, None, Dict[str, str]] = None + use_cuda: bool = False + device: str = "cpu" + +@dataclass +class WiredTableOutput: + pred_html: Optional[str] = None + cell_bboxes: Optional[np.ndarray] = None + logic_points: Optional[np.ndarray] = None + elapse: Optional[float] = None + +@dataclass +class LinelessTableOutput: + pred_html: Optional[str] = None + cell_bboxes: Optional[np.ndarray] = None + logic_points: Optional[np.ndarray] = None + elapse: Optional[float] = None +``` +```python wired_table_rec = WiredTableRecognition() html, elasp, polygons, logic_points, ocr_res = wired_table_rec( img, # Image Union[str, np.ndarray, bytes, Path, PIL.Image.Image] ocr_result, # Input rapidOCR recognition result, use internal rapidocr model by default if not provided - version="v2", # Default to using v2 line model, switch to AliDamo model by changing to v1 enhance_box_line=True, # Enhance box line find (turn off to avoid excessive cutting, turn on to reduce missed cuts), default is True need_ocr=True, # Whether to perform OCR recognition, default is True rec_again=True, # Whether to re-recognize table boxes without detected text by cropping them separately, default is True diff --git a/setup_lineless.py b/setup_lineless.py index 5abea37..a712ecd 100644 --- a/setup_lineless.py +++ b/setup_lineless.py @@ -53,8 +53,7 @@ def read_txt(txt_path: Union[Path, str]) -> List[str]: license="Apache-2.0", install_requires=read_txt("requirements.txt"), include_package_data=True, - packages=[MODULE_NAME, f"{MODULE_NAME}.models"], - package_data={"": ["*.onnx"]}, + packages=[MODULE_NAME], keywords=["tsr,ocr,table-recognition"], classifiers=[ "Programming Language :: Python :: 3.6", diff --git a/setup_table_cls.py b/setup_table_cls.py index edcb56d..9cf4e18 100644 --- a/setup_table_cls.py +++ b/setup_table_cls.py @@ -46,8 +46,7 @@ def read_txt(txt_path: Union[Path, str]) -> List[str]: license="Apache-2.0", install_requires=read_txt("requirements.txt"), include_package_data=True, - packages=[MODULE_NAME, f"{MODULE_NAME}.models"], - package_data={"": ["*.onnx"]}, + 
packages=[MODULE_NAME], keywords=["table-classifier", "wired", "wireless", "table-recognition"], classifiers=[ "Programming Language :: Python :: 3.6", diff --git a/setup_wired.py b/setup_wired.py index eb4a127..856bc58 100644 --- a/setup_wired.py +++ b/setup_wired.py @@ -53,8 +53,7 @@ def read_txt(txt_path: Union[Path, str]) -> List[str]: license="Apache-2.0", install_requires=read_txt("requirements.txt"), include_package_data=True, - packages=[MODULE_NAME, f"{MODULE_NAME}.models"], - package_data={"": ["*.onnx"]}, + packages=[MODULE_NAME], keywords=["tsr,ocr,table-recognition"], classifiers=[ "Programming Language :: Python :: 3.6", From f8ed5f5514880b024526e598a37b4f60f7438514 Mon Sep 17 00:00:00 2001 From: Jokcer <519548295@qq.com> Date: Wed, 19 Mar 2025 22:03:31 +0800 Subject: [PATCH 6/8] chore: add teds score & fix setup package --- README.md | 26 ++++++++++++++------------ README_en.md | 3 +++ setup_lineless.py | 2 +- setup_table_cls.py | 2 +- setup_wired.py | 2 +- 5 files changed, 20 insertions(+), 15 deletions(-) diff --git a/README.md b/README.md index a773ec7..1a142e9 100644 --- a/README.md +++ b/README.md @@ -23,6 +23,7 @@ - 输入输出格式对齐RapidTable - 支持模型自动下载 - 增加来自paddle的新表格分类模型 + - 增加最新PaddleX表格识别模型测评值 ### 简介 💖该仓库是用来对文档中表格做结构化识别的推理库,包括来自阿里读光有线和无线表格识别模型,llaipython(微信)贡献的有线表格模型,网易Qanything内置表格分类模型等。\ @@ -56,18 +57,19 @@ Surya-Tabled 使用内置ocr模块,表格模型为行列识别模型,无法识别单元格合并,导致分数较低 | 方法 | TEDS | TEDS-only-structure | -|:---------------------------------------------------------------------------------------------------------|:-----------:|:-------------------:| -| [surya-tabled(--skip-detect)](https://github.com/VikParuchuri/tabled) | 0.33437 | 0.65865 | -| [surya-tabled](https://github.com/VikParuchuri/tabled) | 0.33940 | 0.67103 | -| [deepdoctection(table-transformer)](https://github.com/deepdoctection/deepdoctection?tab=readme-ov-file) | 0.59975 | 0.69918 | -| [ppstructure_table_master](https://github.com/PaddlePaddle/PaddleOCR/tree/main/ppstructure) | 0.61606 | 0.73892 | -| [ppsturcture_table_engine](https://github.com/PaddlePaddle/PaddleOCR/tree/main/ppstructure) | 0.67924 | 0.78653 | -| [StructEqTable](https://github.com/UniModal4Reasoning/StructEqTable-Deploy) | 0.67310 | 0.81210 | -| [RapidTable(SLANet)](https://github.com/RapidAI/RapidTable) | 0.71654 | 0.81067 | -| table_cls + wired_table_rec v1 + lineless_table_rec | 0.75288 | 0.82574 | -| table_cls + wired_table_rec v2 + lineless_table_rec | 0.77676 | 0.84580 | -| [RapidTable(SLANet-plus)](https://github.com/RapidAI/RapidTable) | 0.84481 | 0.91369 | -| [RapidTable(unitable)](https://github.com/RapidAI/RapidTable) | **0.86200** | **0.91813** | +|:---------------------------------------------------------------------------------------------------------|:-----------:|:-----------------:| +| [surya-tabled(--skip-detect)](https://github.com/VikParuchuri/tabled) | 0.33437 | 0.65865 | +| [surya-tabled](https://github.com/VikParuchuri/tabled) | 0.33940 | 0.67103 | +| [deepdoctection(table-transformer)](https://github.com/deepdoctection/deepdoctection?tab=readme-ov-file) | 0.59975 | 0.69918 | +| [ppstructure_table_master](https://github.com/PaddlePaddle/PaddleOCR/tree/main/ppstructure) | 0.61606 | 0.73892 | +| [ppsturcture_table_engine](https://github.com/PaddlePaddle/PaddleOCR/tree/main/ppstructure) | 0.67924 | 0.78653 | +| [StructEqTable](https://github.com/UniModal4Reasoning/StructEqTable-Deploy) | 0.67310 | 0.81210 | +| [RapidTable(SLANet)](https://github.com/RapidAI/RapidTable) | 0.71654 | 0.81067 | +| table_cls + wired_table_rec v1 + 
lineless_table_rec | 0.75288 | 0.82574 | +| table_cls + wired_table_rec v2 + lineless_table_rec | 0.77676 | 0.84580 | +| [PaddleX(SLANetXt+RT-DERT)](https://github.com/PaddlePaddle/PaddleX) | 0.79900 | **0.92222** | +| [RapidTable(SLANet-plus)](https://github.com/RapidAI/RapidTable) | 0.84481 | 0.91369 | +| [RapidTable(unitable)](https://github.com/RapidAI/RapidTable) | **0.86200** | 0.91813 | ### 使用建议 wired_table_rec_v2(有线表格精度最高): 通用场景有线表格(论文,杂志,期刊, 收据,单据,账单) diff --git a/README_en.md b/README_en.md index a0297e4..b9f0084 100644 --- a/README_en.md +++ b/README_en.md @@ -67,7 +67,10 @@ Surya-Tabled uses its built-in OCR module, which is a row-column recognition mod | [RapidTable(SLANet)](https://github.com/RapidAI/RapidTable) | 0.71654 | 0.81067 | | table_cls + wired_table_rec v1 + lineless_table_rec | 0.75288 | 0.82574 | | table_cls + wired_table_rec v2 + lineless_table_rec | 0.77676 | 0.84580 | +| [PaddleX(SLANetXt+RT-DERT)](https://github.com/PaddlePaddle/PaddleX) | 0.79900 | **0.92222** | | [RapidTable(SLANet-plus)](https://github.com/RapidAI/RapidTable) | **0.84481** | **0.91369** | +| [RapidTable(unitable)](https://github.com/RapidAI/RapidTable) | **0.86200** | 0.91813 | + ### Usage Recommendations wired_table_rec_v2 (highest precision for wired tables): General scenes for wired tables (papers, magazines, journals, receipts, invoices, bills) diff --git a/setup_lineless.py b/setup_lineless.py index a712ecd..87694a5 100644 --- a/setup_lineless.py +++ b/setup_lineless.py @@ -53,7 +53,7 @@ def read_txt(txt_path: Union[Path, str]) -> List[str]: license="Apache-2.0", install_requires=read_txt("requirements.txt"), include_package_data=True, - packages=[MODULE_NAME], + packages=[MODULE_NAME, f"{MODULE_NAME}.utils"], keywords=["tsr,ocr,table-recognition"], classifiers=[ "Programming Language :: Python :: 3.6", diff --git a/setup_table_cls.py b/setup_table_cls.py index 9cf4e18..6169008 100644 --- a/setup_table_cls.py +++ b/setup_table_cls.py @@ -46,7 +46,7 @@ def read_txt(txt_path: Union[Path, str]) -> List[str]: license="Apache-2.0", install_requires=read_txt("requirements.txt"), include_package_data=True, - packages=[MODULE_NAME], + packages=[MODULE_NAME, f"{MODULE_NAME}.utils"], keywords=["table-classifier", "wired", "wireless", "table-recognition"], classifiers=[ "Programming Language :: Python :: 3.6", diff --git a/setup_wired.py b/setup_wired.py index 856bc58..271f6a5 100644 --- a/setup_wired.py +++ b/setup_wired.py @@ -53,7 +53,7 @@ def read_txt(txt_path: Union[Path, str]) -> List[str]: license="Apache-2.0", install_requires=read_txt("requirements.txt"), include_package_data=True, - packages=[MODULE_NAME], + packages=[MODULE_NAME, f"{MODULE_NAME}.utils"], keywords=["tsr,ocr,table-recognition"], classifiers=[ "Programming Language :: Python :: 3.6", From 05d10ee98c7c8d08380ec16c941e4c1be1cbec38 Mon Sep 17 00:00:00 2001 From: Jokcer <519548295@qq.com> Date: Sun, 30 Mar 2025 12:20:01 +0800 Subject: [PATCH 7/8] feat: sup for rapidOCR 2.0 --- .github/workflows/lineless_table_rec.yml | 4 +- .github/workflows/wired_table_rec.yml | 2 + README.md | 113 +++++++++++++---------- README_en.md | 108 +++++++++++++--------- demo_all.py | 30 ++++-- demo_lineless.py | 14 ++- demo_wired.py | 14 ++- lineless_table_rec/main.py | 35 ++----- requirements.txt | 1 - tests/test_lineless_table_rec.py | 16 +++- tests/test_wired_table_rec.py | 33 +++++-- wired_table_rec/main.py | 32 ++----- 12 files changed, 221 insertions(+), 181 deletions(-) diff --git a/.github/workflows/lineless_table_rec.yml 
b/.github/workflows/lineless_table_rec.yml index 8a93fe3..b36e266 100644 --- a/.github/workflows/lineless_table_rec.yml +++ b/.github/workflows/lineless_table_rec.yml @@ -29,7 +29,7 @@ jobs: run: | pip install -r requirements.txt pip install pytest - + pip install rapidocr pytest tests/test_lineless_table_rec.py GenerateWHL_PushPyPi: @@ -50,7 +50,7 @@ jobs: pip install -r requirements.txt python -m pip install --upgrade pip pip install wheel get_pypi_latest_version - + pip install rapidocr python setup_lineless.py bdist_wheel "${{ github.ref_name }}" # - name: Publish distribution 📦 to Test PyPI diff --git a/.github/workflows/wired_table_rec.yml b/.github/workflows/wired_table_rec.yml index e14f7aa..c4bbce1 100644 --- a/.github/workflows/wired_table_rec.yml +++ b/.github/workflows/wired_table_rec.yml @@ -28,6 +28,7 @@ jobs: run: | pip install -r requirements.txt pip install pytest beautifulsoup4 + pip install rapidocr pytest tests/test_wired_table_rec.py GenerateWHL_PushPyPi: @@ -48,6 +49,7 @@ jobs: pip install -r requirements.txt python -m pip install --upgrade pip pip install wheel get_pypi_latest_version + pip install rapidocr python setup_wired.py bdist_wheel "${{ github.ref_name }}" - name: Publish distribution 📦 to PyPI diff --git a/README.md b/README.md index 1a142e9..5e9fc71 100644 --- a/README.md +++ b/README.md @@ -19,11 +19,12 @@ - 补充文档扭曲矫正/去模糊/去阴影/二值化方案,可作为前置处理 [RapidUnDistort](https://github.com/Joker1212/RapidUnWrap) - **2025.1.9** - RapidTable支持了 unitable 模型,精度更高支持torch推理,补充测评数据 -- **2025.3.9** +- **2025.3.30** - 输入输出格式对齐RapidTable - 支持模型自动下载 - 增加来自paddle的新表格分类模型 - 增加最新PaddleX表格识别模型测评值 + - 支持 rapidocr 2.0 取消重复ocr检测 ### 简介 💖该仓库是用来对文档中表格做结构化识别的推理库,包括来自阿里读光有线和无线表格识别模型,llaipython(微信)贡献的有线表格模型,网易Qanything内置表格分类模型等。\ @@ -79,9 +80,10 @@ wired_table_rec_v2 对1500px内大小的图片效果最好,所以分辨率超 SLANet-plus/unitable (综合精度最高): 文档场景表格(论文,杂志,期刊中的表格) ### 安装 - +rapidocr2.0以上版本支持torch,onnx,paddle,openvino等多引擎切换,详情参考[rapidocr文档](https://rapidai.github.io/RapidOCRDocs/main/install_usage/rapidocr/usage/) ``` python {linenos=table} pip install wired_table_rec lineless_table_rec table_cls +pip install rapidocr ``` ### 快速使用 @@ -89,47 +91,63 @@ pip install wired_table_rec lineless_table_rec table_cls ``` python {linenos=table} from pathlib import Path -from wired_table_rec.utils.utils import VisTable +from demo_wired import viser from table_cls import TableCls from wired_table_rec.main import WiredTableInput, WiredTableRecognition from lineless_table_rec.main import LinelessTableInput, LinelessTableRecognition -from rapidocr_onnxruntime import RapidOCR, VisRes - -# 初始化引擎 -wired_input = WiredTableInput() -lineless_input = LinelessTableInput() -wired_engine = WiredTableRecognition(wired_input) -lineless_engine = LinelessTableRecognition(lineless_input) -# 默认小yolo模型(0.1s),可切换为精度更高yolox(0.25s),更快的qanything(0.07s)模型或paddle模型(0.03s) -table_cls = TableCls() -img_path = f'tests/test_files/table.jpg' - -cls,elasp = table_cls(img_path) -if cls == 'wired': - table_engine = wired_engine -else: - table_engine = lineless_engine - -table_results = table_engine(img_path, enhance_box_line=False) -# 使用RapidOCR输入 -# ocr_engine = RapidOCR() -# ocr_result, _ = ocr_engine(img_path) -# table_results = table_engine(img_path, ocr_result=ocr_result) - -# 可视化并存储结果,包含识别框+行列坐标 -# save_dir = Path("outputs") -# save_dir.mkdir(parents=True, exist_ok=True) -# -# save_html_path = f"outputs/{Path(img_path).stem}.html" -# save_drawed_path = f"outputs/{Path(img_path).stem}_table_vis{Path(img_path).suffix}" -# save_logic_path = ( -# 
f"outputs/{Path(img_path).stem}_table_vis_logic{Path(img_path).suffix}" -# ) -# -# vis_table = VisTable() -# vis_imged = vis_table( -# img_path, table_results, save_html_path, save_drawed_path, save_logic_path -# ) +from rapidocr import RapidOCR + + +if __name__ == "__main__": + # Init + wired_input = WiredTableInput() + lineless_input = LinelessTableInput() + wired_engine = WiredTableRecognition(wired_input) + lineless_engine = LinelessTableRecognition(lineless_input) + # 默认小yolo模型(0.1s),可切换为精度更高yolox(0.25s),更快的qanything(0.07s)模型或paddle模型(0.03s) + table_cls = TableCls() + img_path = f"tests/test_files/table.jpg" + + cls, elasp = table_cls(img_path) + if cls == "wired": + table_engine = wired_engine + else: + table_engine = lineless_engine + + # 使用RapidOCR输入 + ocr_engine = RapidOCR() + rapid_ocr_output = ocr_engine(img_path, return_word_box=True) + ocr_result = list(zip(rapid_ocr_output.boxes, rapid_ocr_output.txts, rapid_ocr_output.scores)) + table_results = table_engine( + img_path, ocr_result=ocr_result, enhance_box_line=False + ) + + + # 使用单字识别 + # word_results = rapid_ocr_output.word_results + # ocr_result = [ + # [word_result[2], word_result[0], word_result[1]] for word_result in word_results + # ] + # table_results = table_engine( + # img_path, ocr_result=ocr_result, enhance_box_line=False + # ) + + # Save + # save_dir = Path("outputs") + # save_dir.mkdir(parents=True, exist_ok=True) + # + # save_html_path = f"outputs/{Path(img_path).stem}.html" + # save_drawed_path = f"outputs/{Path(img_path).stem}_table_vis{Path(img_path).suffix}" + # save_logic_path = ( + # f"outputs/{Path(img_path).stem}_table_vis_logic{Path(img_path).suffix}" + # ) + + # Visualize table rec result + # vis_imged = viser( + # img_path, table_results, save_html_path, save_drawed_path, save_logic_path + # ) + + ``` @@ -137,13 +155,14 @@ table_results = table_engine(img_path, enhance_box_line=False) ```python # 将单字box转换为行识别同样的结构) -from rapidocr_onnxruntime import RapidOCR -from wired_table_rec.utils.utils_table_recover import trans_char_ocr_res - +from rapidocr import RapidOCR img_path = "tests/test_files/wired/table4.jpg" ocr_engine = RapidOCR() -ocr_res, _ = ocr_engine(img_path, return_word_box=True) -ocr_res = trans_char_ocr_res(ocr_res) +rapid_ocr_output = ocr_engine(img_path, return_word_box=True) +word_results = rapid_ocr_output.word_results +ocr_result = [ + [word_result[2], word_result[0], word_result[1]] for word_result in word_results +] ``` #### 表格旋转及透视修正 @@ -230,14 +249,12 @@ table_results = wired_table_rec( row_threshold=10, # 识别框上边界y坐标差值小于row_threshold的默认同行 rotated_fix=True, # wiredV2支持,轻度旋转(-45°~45°)矫正,默认为True need_ocr=True, # 是否进行OCR识别, 默认为True - rec_again=True,# 是否针对未识别到文字的表格框,进行单独截取再识别,默认为True ) lineless_table_rec = LinelessTableRecognition(LinelessTableInput()) table_results = lineless_table_rec( img, # 图片 Union[str, np.ndarray, bytes, Path, PIL.Image.Image] ocr_result, # 输入rapidOCR识别结果,不传默认使用内部rapidocr模型 need_ocr=True, # 是否进行OCR识别, 默认为True - rec_again=True,# 是否针对未识别到文字的表格框,进行单独截取再识别,默认为True ) ``` @@ -268,7 +285,7 @@ table_results = lineless_table_rec( ```mermaid flowchart TD A[/表格图片/] --> B([表格分类 table_cls]) - B --> C([有线表格识别 wired_table_rec]) & D([无线表格识别 lineless_table_rec]) --> E([文字识别 rapidocr_onnxruntime]) + B --> C([有线表格识别 wired_table_rec]) & D([无线表格识别 lineless_table_rec]) --> E([文字识别 rapidocr]) E --> F[/html结构化输出/] ``` diff --git a/README_en.md b/README_en.md index b9f0084..ce491fe 100644 --- a/README_en.md +++ b/README_en.md @@ -17,10 +17,11 @@ - Add document preprocessing solutions for 
diff --git a/README_en.md b/README_en.md
index b9f0084..ce491fe 100644
--- a/README_en.md
+++ b/README_en.md
@@ -17,10 +17,11 @@
   - Add document preprocessing solutions for distortion correction, deblurring, shadow removal, and binarization. [RapidUnDistort](https://github.com/Joker1212/RapidUnWrap)
 - **2025.1.9**
   - RapidTable now supports the Unitable model; evaluation data has been added.
-- **2025.3.9**
+- **2025.3.30**
   - Align input and output formats with RapidTable
   - Support automatic model downloading
   - Introduce a new table classification model from [PaddleOCR](https://github.com/PaddlePaddle/PaddleX/blob/release/3.0-rc/docs/module_usage/tutorials/ocr_modules/table_classification.en.md).
+  - Support rapidocr 2.0 and drop the duplicate OCR pass
 
 ### Introduction
 💖 This repository serves as an inference library for structured recognition of tables within documents, including models for wired and wireless table recognition from Alibaba DulaLight, a wired table model from llaipython (WeChat), and a built-in table classification model from NetEase Qanything.
 
@@ -81,6 +82,7 @@ paddlex-SLANet-plus (highest overall precision): Document scene tables (tables i
 
 ```python
 pip install wired_table_rec lineless_table_rec table_cls
+pip install rapidocr
 ```
 
 ### Quick start
@@ -89,59 +91,75 @@
 ``` python {linenos=table}
 from pathlib import Path
 
-from wired_table_rec.utils.utils import VisTable
+from demo_wired import viser
 from table_cls import TableCls
 from wired_table_rec.main import WiredTableInput, WiredTableRecognition
 from lineless_table_rec.main import LinelessTableInput, LinelessTableRecognition
-from rapidocr_onnxruntime import RapidOCR, VisRes
-
-# init engines
-wired_input = WiredTableInput()
-lineless_input = LinelessTableInput()
-wired_engine = WiredTableRecognition(wired_input)
-lineless_engine = LinelessTableRecognition(lineless_input)
-# The default model is a small YOLO model (0.1s inference time), which can be switched to higher-precision YOLOX (0.25s), faster QAnything (0.07s), or PaddlePaddle models (0.03s).
-table_cls = TableCls()
-img_path = f'tests/test_files/table.jpg'
-
-cls,elasp = table_cls(img_path)
-if cls == 'wired':
-    table_engine = wired_engine
-else:
-    table_engine = lineless_engine
-
-table_results = table_engine(img_path, enhance_box_line=False)
-# use rapidOCR as input
-# ocr_engine = RapidOCR()
-# ocr_result, _ = ocr_engine(img_path)
-# table_results = table_engine(img_path, ocr_result=ocr_result)
-
-# Visualize and store the results, including detection bounding boxes and row/column coordinates.
-# save_dir = Path("outputs")
-# save_dir.mkdir(parents=True, exist_ok=True)
-#
-# save_html_path = f"outputs/{Path(img_path).stem}.html"
-# save_drawed_path = f"outputs/{Path(img_path).stem}_table_vis{Path(img_path).suffix}"
-# save_logic_path = (
-#     f"outputs/{Path(img_path).stem}_table_vis_logic{Path(img_path).suffix}"
-# )
-#
-# vis_table = VisTable()
-# vis_imged = vis_table(
-#     img_path, table_results, save_html_path, save_drawed_path, save_logic_path
-# )
+from rapidocr import RapidOCR
+
+
+if __name__ == "__main__":
+    # Init
+    wired_input = WiredTableInput()
+    lineless_input = LinelessTableInput()
+    wired_engine = WiredTableRecognition(wired_input)
+    lineless_engine = LinelessTableRecognition(lineless_input)
+    # yolo (0.1s), yolox (0.25s), qanything (0.07s), paddle (0.03s)
+    table_cls = TableCls()
+    img_path = f"tests/test_files/table.jpg"
+
+    cls, elasp = table_cls(img_path)
+    if cls == "wired":
+        table_engine = wired_engine
+    else:
+        table_engine = lineless_engine
+
+    # use rapid ocr as input
+    ocr_engine = RapidOCR()
+    rapid_ocr_output = ocr_engine(img_path, return_word_box=True)
+    ocr_result = list(zip(rapid_ocr_output.boxes, rapid_ocr_output.txts, rapid_ocr_output.scores))
+    table_results = table_engine(
+        img_path, ocr_result=ocr_result, enhance_box_line=False
+    )
+
+
+    # use word-box (single-character) recognition
+    # word_results = rapid_ocr_output.word_results
+    # ocr_result = [
+    #     [word_result[2], word_result[0], word_result[1]] for word_result in word_results
+    # ]
+    # table_results = table_engine(
+    #     img_path, ocr_result=ocr_result, enhance_box_line=False
+    # )
+
+    # Save
+    # save_dir = Path("outputs")
+    # save_dir.mkdir(parents=True, exist_ok=True)
+    #
+    # save_html_path = f"outputs/{Path(img_path).stem}.html"
+    # save_drawed_path = f"outputs/{Path(img_path).stem}_table_vis{Path(img_path).suffix}"
+    # save_logic_path = (
+    #     f"outputs/{Path(img_path).stem}_table_vis_logic{Path(img_path).suffix}"
+    # )
+
+    # Visualize table rec result
+    # vis_imged = viser(
+    #     img_path, table_results, save_html_path, save_drawed_path, save_logic_path
+    # )
+
 ```
 
 #### Single Character OCR Matching
 
 ```python
 # Convert single character boxes to the same structure as line recognition
-from rapidocr_onnxruntime import RapidOCR
-from wired_table_rec.utils.utils_table_recover import trans_char_ocr_res
-
+from rapidocr import RapidOCR
 img_path = "tests/test_files/wired/table4.jpg"
 ocr_engine = RapidOCR()
-ocr_res, _ = ocr_engine(img_path, return_word_box=True)
-ocr_res = trans_char_ocr_res(ocr_res)
+rapid_ocr_output = ocr_engine(img_path, return_word_box=True)
+word_results = rapid_ocr_output.word_results
+ocr_result = [
+    [word_result[2], word_result[0], word_result[1]] for word_result in word_results
+]
 ```
 
 #### Table Rotation and Perspective Correction
@@ -251,7 +269,7 @@ html, elasp, polygons, logic_points, ocr_res = lineless_table_rec(
 ```mermaid
 flowchart TD
     A[/table image/] --> B([table cls table_cls])
-    B --> C([wired_table_rec]) & D([lineless_table_rec]) --> E([rapidocr_onnxruntime])
+    B --> C([wired_table_rec]) & D([lineless_table_rec]) --> E([rapidocr])
     E --> F[/html output/]
 ```
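[Editor's note] Both READMEs now repeat the same zip/word-box conversion inline. A hedged helper that wraps the two variants shown above (illustrative only, not part of the patch; it assumes nothing beyond the rapidocr 2.0 output attributes the patch itself uses: `.boxes`, `.txts`, `.scores`, `.word_results`):

```python
# Sketch of a conversion helper. Line-level results become (box, text, score)
# triplets; word-level results come back as (text, score, box) and are
# reordered into the same (box, text, score) shape the table engines expect.
from rapidocr import RapidOCR

def to_ocr_result(rapid_ocr_output, use_word_boxes: bool = False):
    if use_word_boxes:
        return [
            [word_result[2], word_result[0], word_result[1]]
            for word_result in rapid_ocr_output.word_results
        ]
    return list(
        zip(rapid_ocr_output.boxes, rapid_ocr_output.txts, rapid_ocr_output.scores)
    )

ocr_engine = RapidOCR()
rapid_ocr_output = ocr_engine("tests/test_files/table.jpg", return_word_box=True)
ocr_result = to_ocr_result(rapid_ocr_output, use_word_boxes=True)
```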
diff --git a/demo_all.py b/demo_all.py
index 8ee15d5..653a947 100644
--- a/demo_all.py
+++ b/demo_all.py
@@ -1,6 +1,8 @@
 from table_cls import TableCls
 from wired_table_rec.main import WiredTableInput, WiredTableRecognition
 from lineless_table_rec.main import LinelessTableInput, LinelessTableRecognition
+from rapidocr import RapidOCR
+
 
 if __name__ == "__main__":
     # Init
@@ -18,13 +20,26 @@
     else:
         table_engine = lineless_engine
 
-    table_results = table_engine(img_path, enhance_box_line=False)
     # Use RapidOCR output as input
-    # ocr_engine = RapidOCR()
-    # ocr_result, _ = ocr_engine(img_path)
-    # table_results = table_engine(img_path, ocr_result=ocr_result)
+    ocr_engine = RapidOCR()
+    rapid_ocr_output = ocr_engine(img_path, return_word_box=True)
+    ocr_result = list(
+        zip(rapid_ocr_output.boxes, rapid_ocr_output.txts, rapid_ocr_output.scores)
+    )
+    table_results = table_engine(
+        img_path, ocr_result=ocr_result, enhance_box_line=False
+    )
 
-    # Visualize table rec result
+    # Use single-character recognition
+    # word_results = rapid_ocr_output.word_results
+    # ocr_result = [
+    #     [word_result[2], word_result[0], word_result[1]] for word_result in word_results
+    # ]
+    # table_results = table_engine(
+    #     img_path, ocr_result=ocr_result, enhance_box_line=False
+    # )
+
+    # Save
     # save_dir = Path("outputs")
     # save_dir.mkdir(parents=True, exist_ok=True)
     #
@@ -34,8 +49,7 @@
     #     f"outputs/{Path(img_path).stem}_table_vis_logic{Path(img_path).suffix}"
     # )
 
-    #
-    # vis_table = VisTable()
-    # vis_imged = vis_table(
+    # Visualize table rec result
+    # vis_imged = viser(
     #     img_path, table_results, save_html_path, save_drawed_path, save_logic_path
     # )
diff --git a/demo_lineless.py b/demo_lineless.py
index c1519e7..8ce1684 100644
--- a/demo_lineless.py
+++ b/demo_lineless.py
@@ -3,7 +3,7 @@
 # @Contact: liekkaskono@163.com
 from pathlib import Path
 
-from rapidocr_onnxruntime import RapidOCR
+from rapidocr import RapidOCR
 
 from lineless_table_rec import LinelessTableRecognition
 from lineless_table_rec.main import LinelessTableInput
@@ -19,11 +19,17 @@
 if __name__ == "__main__":
     img_path = "tests/test_files/lineless_table_recognition.jpg"
 
-    ocr_result, _ = ocr_engine(img_path)
-    boxes, txts, scores = list(zip(*ocr_result))
+    rapid_ocr_output = ocr_engine(img_path, return_word_box=True)
+    ocr_result = list(
+        zip(rapid_ocr_output.boxes, rapid_ocr_output.txts, rapid_ocr_output.scores)
+    )
+
+    # Use single-character recognition
+    # word_results = rapid_ocr_output.word_results
+    # ocr_result = [[word_result[2], word_result[0], word_result[1]] for word_result in word_results]
 
     # Table Rec
-    table_results = table_engine(img_path)
+    table_results = table_engine(img_path, ocr_result=ocr_result)
     table_html_str, table_cell_bboxes = (
         table_results.pred_html,
         table_results.cell_bboxes,
diff --git a/demo_wired.py b/demo_wired.py
index 65d2eb7..1db8037 100644
--- a/demo_wired.py
+++ b/demo_wired.py
@@ -3,7 +3,7 @@
 # @Contact: liekkaskono@163.com
 from pathlib import Path
 
-from rapidocr_onnxruntime import RapidOCR
+from rapidocr import RapidOCR
 
 from wired_table_rec import WiredTableRecognition
 from wired_table_rec.main import WiredTableInput
@@ -18,11 +18,17 @@
 if __name__ == "__main__":
     img_path = "tests/test_files/wired/bad_case_1.png"
 
-    ocr_result, _ = ocr_engine(img_path)
-    boxes, txts, scores = list(zip(*ocr_result))
+    rapid_ocr_output = ocr_engine(img_path, return_word_box=True)
+    ocr_result = list(
+        zip(rapid_ocr_output.boxes, rapid_ocr_output.txts, rapid_ocr_output.scores)
+    )
+
+    # Use single-character recognition
+    # word_results = rapid_ocr_output.word_results
+    # ocr_result = [[word_result[2], word_result[0], word_result[1]] for word_result in word_results]
 
     # Table Rec
-    table_results = table_engine(img_path)
+    table_results = table_engine(img_path, ocr_result)
     table_html_str, table_cell_bboxes = (
         table_results.pred_html,
         table_results.cell_bboxes,
diff --git a/lineless_table_rec/main.py b/lineless_table_rec/main.py
index 9b4a705..0a313c2 100644
--- a/lineless_table_rec/main.py
+++ b/lineless_table_rec/main.py
@@ -1,16 +1,14 @@
 # -*- encoding: utf-8 -*-
 # @Author: SWHL
 # @Contact: liekkaskono@163.com
-import importlib
 import logging
 import time
 import traceback
 from dataclasses import dataclass, asdict
 from enum import Enum
 from pathlib import Path
-from typing import Dict, List, Union, Optional
+from typing import Dict, List, Union, Optional, Any
 
-import cv2
 import numpy as np
 
 from .table_structure_lore import TSRLore
@@ -20,7 +18,6 @@
     box_4_2_poly_to_box_4_1,
     filter_duplicated_box,
     gather_ocr_list_by_row,
-    get_rotate_crop_image,
     match_ocr_cell,
     plot_html_table,
     sorted_ocr_boxes,
@@ -69,10 +66,6 @@ def __init__(self, config: LinelessTableInput):
         config.model_path = self.get_model_path(config.model_type, config.model_path)
         self.table_structure = TSRLore(asdict(config))
         self.load_img = LoadImage()
-        try:
-            self.ocr = importlib.import_module("rapidocr_onnxruntime").RapidOCR()
-        except ModuleNotFoundError:
-            self.ocr = None
 
     def __call__(
         self,
@@ -81,10 +74,8 @@ def __call__(
         **kwargs,
     ) -> LinelessTableOutput:
         s = time.perf_counter()
-        rec_again = True
         need_ocr = True
         if kwargs:
-            rec_again = kwargs.get("rec_again", True)
             need_ocr = kwargs.get("need_ocr", True)
         img = self.load_img(content)
         try:
@@ -100,12 +91,10 @@ def __call__(
                 time.perf_counter() - s,
             )
 
-            if ocr_result is None and need_ocr:
-                ocr_result, _ = self.ocr(img)
             # Match OCR results to cells
             cell_box_det_map, no_match_ocr_det = match_ocr_cell(ocr_result, polygons)
             # Fill in detection boxes that matched no OCR result
-            cell_box_det_map = self.re_rec(img, polygons, cell_box_det_map, rec_again)
+            cell_box_det_map = self.fill_blank_rec(img, polygons, cell_box_det_map)
             # Convert to an intermediate format: correct box coordinates and merge physical, logical, and OCR boxes into one dict for later steps
             t_rec_ocr_list = self.transform_res(cell_box_det_map, polygons, logi_points)
             # Split nested and overlapping boxes
@@ -248,29 +237,17 @@ def handle_overlap_row_col(self, res):
         res = [res[i] for i in range(len(res)) if i not in deleted_idx]
         return res, grid
 
-    def re_rec(
+    def fill_blank_rec(
         self,
         img: np.ndarray,
         sorted_polygons: np.ndarray,
         cell_box_map: Dict[int, List[str]],
-        rec_again=True,
-    ) -> Dict[int, List[any]]:
+    ) -> Dict[int, List[Any]]:
         """For cells whose polygon matched no OCR result, fill in an empty-text placeholder"""
-        #
         for i in range(sorted_polygons.shape[0]):
             if cell_box_map.get(i):
                 continue
-            if not rec_again:
-                box = sorted_polygons[i]
-                cell_box_map[i] = [[box, "", 1]]
-                continue
-            crop_img = get_rotate_crop_image(img, sorted_polygons[i])
-            pad_img = cv2.copyMakeBorder(
-                crop_img, 5, 5, 100, 100, cv2.BORDER_CONSTANT, value=(255, 255, 255)
-            )
-            rec_res, _ = self.ocr(pad_img, use_det=False, use_cls=True, use_rec=True)
             box = sorted_polygons[i]
-            text = [rec[0] for rec in rec_res]
-            scores = [rec[1] for rec in rec_res]
-            cell_box_map[i] = [[box, "".join(text), min(scores)]]
+            cell_box_map[i] = [[box, "", 1]]
+            continue
         return cell_box_map
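[Editor's note] A standalone sketch of what the new `fill_blank_rec` does, with hypothetical data (the `[box, text, score]` convention is the one the patch uses; everything else here is illustrative):

```python
# Sketch: cell indices with no matched OCR entry are filled with [box, "", 1],
# an empty string and a dummy confidence of 1, so downstream HTML assembly
# can treat every cell uniformly instead of re-running OCR on blank cells.
import numpy as np

sorted_polygons = np.zeros((3, 4, 2))                    # three hypothetical cell polygons
cell_box_map = {0: [[sorted_polygons[0], "abc", 0.99]]}  # only cell 0 has text

for i in range(sorted_polygons.shape[0]):
    if cell_box_map.get(i):
        continue
    cell_box_map[i] = [[sorted_polygons[i], "", 1]]

assert cell_box_map[1][0][1] == ""  # cells 1 and 2 now hold blank placeholders
```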
diff --git a/requirements.txt b/requirements.txt
index 31b4439..5a38371 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,4 +4,3 @@ opencv_python
 scipy
 scikit-image
 Shapely
-rapidocr_onnxruntime>=1.3.8
diff --git a/tests/test_lineless_table_rec.py b/tests/test_lineless_table_rec.py
index e796f94..036609f 100644
--- a/tests/test_lineless_table_rec.py
+++ b/tests/test_lineless_table_rec.py
@@ -4,6 +4,7 @@
 import sys
 from pathlib import Path
 
 import pytest
+from rapidocr import RapidOCR
 
 from lineless_table_rec.main import LinelessTableInput, ModelType
@@ -18,6 +19,7 @@
 test_file_dir = cur_dir / "test_files"
 input_args = LinelessTableInput(model_type=ModelType.LORE.value)
 table_recog = LinelessTableRecognition(input_args)
+ocr_engine = RapidOCR()
 
 
 @pytest.mark.parametrize(
@@ -29,8 +31,11 @@
 )
 def test_input_normal(img_path, table_str_len, td_nums):
     img_path = test_file_dir / img_path
-
-    table_results = table_recog(str(img_path))
+    rapid_ocr_output = ocr_engine(img_path, return_word_box=True)
+    ocr_result = list(
+        zip(rapid_ocr_output.boxes, rapid_ocr_output.txts, rapid_ocr_output.scores)
+    )
+    table_results = table_recog(str(img_path), ocr_result=ocr_result)
     table_html_str, table_cell_bboxes = (
         table_results.pred_html,
         table_results.cell_bboxes,
@@ -259,8 +264,11 @@ def test_plot_html_table(logi_points, cell_box_map, expected_html):
 )
 def test_no_rec_again(img_path, table_str_len, td_nums):
     img_path = test_file_dir / img_path
-
-    table_results = table_recog(str(img_path), rec_again=False)
+    rapid_ocr_output = ocr_engine(img_path, return_word_box=True)
+    ocr_result = list(
+        zip(rapid_ocr_output.boxes, rapid_ocr_output.txts, rapid_ocr_output.scores)
+    )
+    table_results = table_recog(str(img_path), ocr_result=ocr_result)
     table_html_str, table_cell_bboxes = (
         table_results.pred_html,
         table_results.cell_bboxes,
diff --git a/tests/test_wired_table_rec.py b/tests/test_wired_table_rec.py
index 44c3615..206bf24 100644
--- a/tests/test_wired_table_rec.py
+++ b/tests/test_wired_table_rec.py
@@ -6,7 +6,7 @@
 import numpy as np
 import pytest
 from bs4 import BeautifulSoup
-from rapidocr_onnxruntime import RapidOCR
+from rapidocr import RapidOCR
 
 from wired_table_rec.main import WiredTableInput, ModelType
 from wired_table_rec.utils.utils import rescale_size
@@ -41,8 +41,11 @@ def get_td_nums(html: str) -> int:
 
 def test_squeeze_bug():
     img_path = test_file_dir / "squeeze_error.jpeg"
-    ocr_result, _ = ocr_engine(str(img_path))
-    table_results = table_recog(str(img_path))
+    rapid_ocr_output = ocr_engine(img_path, return_word_box=True)
+    ocr_result = list(
+        zip(rapid_ocr_output.boxes, rapid_ocr_output.txts, rapid_ocr_output.scores)
+    )
+    table_results = table_recog(str(img_path), ocr_result=ocr_result)
     table_html_str, table_cell_bboxes = (
         table_results.pred_html,
         table_results.cell_bboxes,
@@ -63,8 +66,11 @@
 
 def test_input_normal(img_path, gt_td_nums, gt2):
     img_path = test_file_dir / img_path
-    ocr_result, _ = ocr_engine(str(img_path))
-    table_results = table_recog(str(img_path))
+    rapid_ocr_output = ocr_engine(img_path, return_word_box=True)
+    ocr_result = list(
+        zip(rapid_ocr_output.boxes, rapid_ocr_output.txts, rapid_ocr_output.scores)
+    )
+    table_results = table_recog(str(img_path), ocr_result=ocr_result)
     table_html_str, table_cell_bboxes = (
         table_results.pred_html,
         table_results.cell_bboxes,
@@ -83,8 +89,13 @@
 
 def test_enhance_box_line(img_path, gt_td_nums):
     img_path = test_file_dir / img_path
-    ocr_result, _ = ocr_engine(str(img_path))
-    table_results = table_recog(str(img_path), enhance_box_line=False)
+    rapid_ocr_output = ocr_engine(img_path, return_word_box=True)
+    ocr_result = list(
+        zip(rapid_ocr_output.boxes, rapid_ocr_output.txts, rapid_ocr_output.scores)
+    )
+    table_results = table_recog(
+        str(img_path), ocr_result=ocr_result, enhance_box_line=False
+    )
     table_html_str, table_cell_bboxes = (
         table_results.pred_html,
         table_results.cell_bboxes,
@@ -304,8 +315,11 @@ def test_plot_html_table(logi_points, cell_box_map, expected_html):
 
 def test_no_rec_again(img_path, gt_td_nums, gt2):
     img_path = test_file_dir / img_path
-    ocr_result, _ = ocr_engine(str(img_path))
-    table_results = table_recog(str(img_path), rec_again=False)
+    rapid_ocr_output = ocr_engine(img_path, return_word_box=True)
+    ocr_result = list(
+        zip(rapid_ocr_output.boxes, rapid_ocr_output.txts, rapid_ocr_output.scores)
+    )
+    table_results = table_recog(str(img_path), ocr_result=ocr_result)
     table_html_str, table_cell_bboxes = (
         table_results.pred_html,
         table_results.cell_bboxes,
@@ -324,7 +338,6 @@
 
 def test_no_ocr(img_path, html_output, points_len):
     img_path = test_file_dir / img_path
-    ocr_result, _ = ocr_engine(str(img_path))
     table_results = table_recog(str(img_path), need_ocr=False)
     table_html_str, table_cell_bboxes, table_logic_points = (
         table_results.pred_html,
diff --git a/wired_table_rec/main.py b/wired_table_rec/main.py
index 2afa8c3..161cb45 100644
--- a/wired_table_rec/main.py
+++ b/wired_table_rec/main.py
@@ -75,11 +75,6 @@ def __init__(self, config: WiredTableInput):
 
         self.table_recover = TableRecover()
 
-        try:
-            self.ocr = importlib.import_module("rapidocr_onnxruntime").RapidOCR()
-        except ModuleNotFoundError:
-            self.ocr = None
-
     def __call__(
         self,
         img: InputType,
@@ -87,12 +82,10 @@ def __call__(
         **kwargs,
     ) -> WiredTableOutput:
         s = time.perf_counter()
-        rec_again = True
         need_ocr = True
         col_threshold = 15
         row_threshold = 10
         if kwargs:
-            rec_again = kwargs.get("rec_again", True)
             need_ocr = kwargs.get("need_ocr", True)
             col_threshold = kwargs.get("col_threshold", 15)
             row_threshold = kwargs.get("row_threshold", 10)
@@ -121,11 +114,9 @@ def __call__(
                 logi_points[idx_list],
                 time.perf_counter() - s,
             )
-            if ocr_result is None and need_ocr:
-                ocr_result, _ = self.ocr(img)
             cell_box_det_map, not_match_orc_boxes = match_ocr_cell(ocr_result, polygons)
             # Fill in detection boxes that matched no OCR result
-            cell_box_det_map = self.re_rec(img, polygons, cell_box_det_map, rec_again)
+            cell_box_det_map = self.fill_blank_rec(img, polygons, cell_box_det_map)
             # Convert to an intermediate format: correct box coordinates and merge physical, logical, and OCR boxes into one dict for later steps
             t_rec_ocr_list = self.transform_res(cell_box_det_map, polygons, logi_points)
             # Sort and merge same-row OCR results within each cell so the output HTML keeps the text's line breaks
@@ -186,30 +177,19 @@ def sort_and_gather_ocr_res(self, res):
         )
         return res
 
-    def re_rec(
+    def fill_blank_rec(
         self,
         img: np.ndarray,
         sorted_polygons: np.ndarray,
         cell_box_map: Dict[int, List[str]],
-        rec_again=True,
     ) -> Dict[int, List[Any]]:
         """For cells whose polygon matched no OCR result, fill in an empty-text placeholder"""
         for i in range(sorted_polygons.shape[0]):
             if cell_box_map.get(i):
                 continue
-            if not rec_again:
-                box = sorted_polygons[i]
-                cell_box_map[i] = [[box, "", 1]]
-                continue
-            crop_img = get_rotate_crop_image(img, sorted_polygons[i])
-            pad_img = cv2.copyMakeBorder(
-                crop_img, 5, 5, 100, 100, cv2.BORDER_CONSTANT, value=(255, 255, 255)
-            )
-            rec_res, _ = self.ocr(pad_img, use_det=False, use_cls=True, use_rec=True)
             box = sorted_polygons[i]
-            text = [rec[0] for rec in rec_res]
-            scores = [rec[1] for rec in rec_res]
-            cell_box_map[i] = [[box, "".join(text), min(scores)]]
+            cell_box_map[i] = [[box, "", 1]]
+            continue
         return cell_box_map
 
     def re_rec_high_precise(
@@ -271,10 +251,10 @@ def main():
     args = parser.parse_args()
 
     try:
-        ocr_engine = importlib.import_module("rapidocr_onnxruntime").RapidOCR()
+        ocr_engine = importlib.import_module("rapidocr").RapidOCR()
     except ModuleNotFoundError as exc:
         raise ModuleNotFoundError(
-            "Please install the rapidocr_onnxruntime by pip install rapidocr_onnxruntime."
+            "Please install rapidocr via `pip install rapidocr`."
         ) from exc
 
     input_args = WiredTableInput()
     table_rec = WiredTableRecognition(input_args)
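[Editor's note] With rapidocr dropped from requirements.txt, the OCR engine is now resolved lazily only at the CLI boundary. The guarded-import pattern from `main()` above, in isolation (a sketch mirroring the patch's own error message):

```python
# Sketch of the optional-dependency pattern: the library no longer
# hard-depends on rapidocr; the CLI resolves it at call time and fails
# with an actionable message if it is missing.
import importlib

try:
    ocr_engine = importlib.import_module("rapidocr").RapidOCR()
except ModuleNotFoundError as exc:
    raise ModuleNotFoundError(
        "Please install rapidocr via `pip install rapidocr`."
    ) from exc
```

The design choice here is to keep the core packages OCR-agnostic: callers supply `ocr_result` themselves, and only the convenience CLI pulls rapidocr in.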
From 0bb5bb027435649249c96424eb9ced4410e8c30b Mon Sep 17 00:00:00 2001
From: Jokcer <519548295@qq.com>
Date: Sun, 30 Mar 2025 20:07:13 +0800
Subject: [PATCH 8/8] chore: update readme

---
 README.md    | 14 +++++++++-----
 README_en.md | 40 +++++++++++++++++++++-------------------
 demo_all.py  |  7 +++----
 3 files changed, 33 insertions(+), 28 deletions(-)

diff --git a/README.md b/README.md
index 5e9fc71..3ed2b33 100644
--- a/README.md
+++ b/README.md
@@ -91,7 +91,7 @@ pip install rapidocr
 ``` python {linenos=table}
 from pathlib import Path
 
-from demo_wired import viser
+from wired_table_rec.utils.utils import VisTable
 from table_cls import TableCls
 from wired_table_rec.main import WiredTableInput, WiredTableRecognition
 from lineless_table_rec.main import LinelessTableInput, LinelessTableRecognition
@@ -104,6 +104,7 @@ if __name__ == "__main__":
     lineless_input = LinelessTableInput()
     wired_engine = WiredTableRecognition(wired_input)
     lineless_engine = LinelessTableRecognition(lineless_input)
+    viser = VisTable()
     # Default is the small YOLO model (0.1s); switch to the more accurate YOLOX (0.25s), the faster QAnything (0.07s), or the Paddle model (0.03s)
     table_cls = TableCls()
     img_path = f"tests/test_files/table.jpg"
@@ -117,12 +118,13 @@ if __name__ == "__main__":
     # Use RapidOCR output as input
     ocr_engine = RapidOCR()
     rapid_ocr_output = ocr_engine(img_path, return_word_box=True)
-    ocr_result = list(zip(rapid_ocr_output.boxes, rapid_ocr_output.txts, rapid_ocr_output.scores))
+    ocr_result = list(
+        zip(rapid_ocr_output.boxes, rapid_ocr_output.txts, rapid_ocr_output.scores)
+    )
     table_results = table_engine(
-        img_path, ocr_result=ocr_result, enhance_box_line=False
+        img_path, ocr_result=ocr_result
     )
-
-
+
     # Use single-character recognition
     # word_results = rapid_ocr_output.word_results
     # ocr_result = [
@@ -149,6 +151,8 @@ if __name__ == "__main__":
 
 
+
+
 ```
 
 #### Single-character OCR matching
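[Editor's note] Patch 8 replaces the ad-hoc `from demo_wired import viser` with the library's own `VisTable`. The visualization step in isolation, as a hedged sketch (it assumes only the call signature used throughout this patch series: image, results, then the HTML, drawn-box, and logic-point output paths; `need_ocr=False` keeps it runnable without rapidocr):

```python
# Sketch: instantiate VisTable once and pass it the recognition output plus
# the three destination paths; the outputs directory must exist beforehand.
from pathlib import Path

from wired_table_rec.main import WiredTableInput, WiredTableRecognition
from wired_table_rec.utils.utils import VisTable

img_path = "tests/test_files/table.jpg"
Path("outputs").mkdir(parents=True, exist_ok=True)

table_results = WiredTableRecognition(WiredTableInput())(img_path, need_ocr=False)

viser = VisTable()
vis_imged = viser(
    img_path,
    table_results,
    f"outputs/{Path(img_path).stem}.html",
    f"outputs/{Path(img_path).stem}_table_vis{Path(img_path).suffix}",
    f"outputs/{Path(img_path).stem}_table_vis_logic{Path(img_path).suffix}",
)
```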
Path("outputs") - # save_dir.mkdir(parents=True, exist_ok=True) - # - # save_html_path = f"outputs/{Path(img_path).stem}.html" - # save_drawed_path = f"outputs/{Path(img_path).stem}_table_vis{Path(img_path).suffix}" - # save_logic_path = ( - # f"outputs/{Path(img_path).stem}_table_vis_logic{Path(img_path).suffix}" - # ) + #save_dir = Path("outputs") + #save_dir.mkdir(parents=True, exist_ok=True) + + #save_html_path = f"outputs/{Path(img_path).stem}.html" + #save_drawed_path = f"outputs/{Path(img_path).stem}_table_vis{Path(img_path).suffix}" + #save_logic_path = ( + # f"outputs/{Path(img_path).stem}_table_vis_logic{Path(img_path).suffix}" + #) # Visualize table rec result - # vis_imged = viser( - # img_path, table_results, save_html_path, save_drawed_path, save_logic_path - # ) + #vis_imged = viser( + # img_path, table_results, save_html_path, save_drawed_path, save_logic_path + #) ``` #### Single Character OCR Matching diff --git a/demo_all.py b/demo_all.py index 653a947..8296b42 100644 --- a/demo_all.py +++ b/demo_all.py @@ -1,15 +1,16 @@ +from wired_table_rec.utils.utils import VisTable from table_cls import TableCls from wired_table_rec.main import WiredTableInput, WiredTableRecognition from lineless_table_rec.main import LinelessTableInput, LinelessTableRecognition from rapidocr import RapidOCR - if __name__ == "__main__": # Init wired_input = WiredTableInput() lineless_input = LinelessTableInput() wired_engine = WiredTableRecognition(wired_input) lineless_engine = LinelessTableRecognition(lineless_input) + viser = VisTable() # 默认小yolo模型(0.1s),可切换为精度更高yolox(0.25s),更快的qanything(0.07s)模型或paddle模型(0.03s) table_cls = TableCls() img_path = f"tests/test_files/table.jpg" @@ -26,9 +27,7 @@ ocr_result = list( zip(rapid_ocr_output.boxes, rapid_ocr_output.txts, rapid_ocr_output.scores) ) - table_results = table_engine( - img_path, ocr_result=ocr_result, enhance_box_line=False - ) + table_results = table_engine(img_path, ocr_result=ocr_result) # 使用单字识别 # word_results = rapid_ocr_output.word_results