From 628edb3c683e3e4abf577c71e52b48f211bd9649 Mon Sep 17 00:00:00 2001 From: liningping <728359849@qq.com> Date: Thu, 21 Aug 2025 07:38:08 +0000 Subject: [PATCH 1/3] feat: add enhancer registry --- enhancers/__init__.py | 17 +++++ enhancers/base_models.py | 32 +++++++++ enhancers/enhancer_registry.py | 105 ++++++++++++++++++++++++++++++ enhancers/information_enhancer.py | 62 ------------------ parsers/__init__.py | 4 +- worker.py | 38 ++++++----- 6 files changed, 179 insertions(+), 79 deletions(-) create mode 100644 enhancers/base_models.py create mode 100644 enhancers/enhancer_registry.py delete mode 100644 enhancers/information_enhancer.py diff --git a/enhancers/__init__.py b/enhancers/__init__.py index e69de29..0509a18 100644 --- a/enhancers/__init__.py +++ b/enhancers/__init__.py @@ -0,0 +1,17 @@ +from .enhancer_registry import ( + ENHANCER_REGISTRY, + get_enhancer, + get_enhancer_class, + get_supported_modalities, + list_registered_enhancers, + register_enhancer, +) + +__all__ = [ + "ENHANCER_REGISTRY", + "get_enhancer", + "get_enhancer_class", + "get_supported_modalities", + "list_registered_enhancers", + "register_enhancer", +] diff --git a/enhancers/base_models.py b/enhancers/base_models.py new file mode 100644 index 0000000..ba77fa7 --- /dev/null +++ b/enhancers/base_models.py @@ -0,0 +1,32 @@ +from abc import ABC, abstractmethod + +from parsers.base_models import ChunkData + + +class InformationEnhancer(ABC): + """信息增强器基类""" + @abstractmethod + async def enhance(self, information: ChunkData) -> ChunkData: + """增强信息""" + pass + +class TableInformationEnhancer(InformationEnhancer): + """表格信息增强器""" + + async def enhance(self, information: ChunkData) -> ChunkData: + """增强信息""" + return information + +class FormulasInformationEnhancer(InformationEnhancer): + """公式信息增强器""" + + async def enhance(self, information: ChunkData) -> ChunkData: + """增强信息""" + return information + +class ImageInformationEnhancer(InformationEnhancer): + """图片信息增强器""" + + async def enhance(self, information: ChunkData) -> ChunkData: + """增强信息""" + return information diff --git a/enhancers/enhancer_registry.py b/enhancers/enhancer_registry.py new file mode 100644 index 0000000..e637c1c --- /dev/null +++ b/enhancers/enhancer_registry.py @@ -0,0 +1,105 @@ +""" +解析器注册器模块 + +提供基于装饰器的解析器自动注册机制,支持多种文件格式的解析器注册和查找。 +""" + +import logging +from collections.abc import Callable + +from enhancers.base_models import InformationEnhancer +from parsers.base_models import ChunkType + +logger = logging.getLogger(__name__) + +# 全局解析器注册表 +ENHANCER_REGISTRY: dict[str, type[InformationEnhancer]] = {} + + +def register_enhancer(modalities: list[str]) -> Callable[[type[InformationEnhancer]], type[InformationEnhancer]]: + """ + 信息增强器注册装饰器 + + Args: + modalities: 支持的模态类型列表,如 [ChunkType.TEXT, ChunkType.IMAGE, ChunkType.TABLE] + + Returns: + 装饰器函数 + + Example: + @register_enhancer([ChunkType.TEXT, ChunkType.IMAGE, ChunkType.TABLE]) + class TextInformationEnhancer(InformationEnhancer): + ... + """ + def decorator(cls: type[InformationEnhancer]) -> type[InformationEnhancer]: + # 验证类是否继承自 InformationEnhancer + if not issubclass(cls, InformationEnhancer): + raise TypeError(f"信息增强器类 {cls.__name__} 必须继承自 InformationEnhancer") + + # 注册到全局注册表 + for modality in modalities: + modality = modality.lower() # 统一转换为小写 + if modality in ENHANCER_REGISTRY: + logger.warning(f"覆盖已存在的信息增强器: {modality} -> {cls.__name__}") + ENHANCER_REGISTRY[modality] = cls + logger.info(f"注册信息增强器: {modality} -> {cls.__name__}") + + return cls + + return decorator + + +def get_enhancer(modality: ChunkType) -> InformationEnhancer | None: + """ + 根据模态类型获取合适的信息增强器实例 + + Args: + modality: 模态类型 + + Returns: + 信息增强器实例,如果没有找到则返回 None + """ + modality_type = modality.value.lower() + + if modality_type not in ENHANCER_REGISTRY: + logger.warning(f"未找到支持 {modality} 格式的信息增强器") + return None + + enhancer_class = ENHANCER_REGISTRY[modality_type] + try: + return enhancer_class() + except Exception as e: + logger.error(f"创建信息增强器实例失败: {enhancer_class.__name__}, 错误: {e}") + return None + +def get_supported_modalities() -> list[str]: + """ + 获取所有支持的模态类型 + + Returns: + 支持的模态类型列表 + """ + return list(ENHANCER_REGISTRY.keys()) + + +def get_enhancer_class(modality: ChunkType) -> type[InformationEnhancer] | None: + """ + 根据模态类型获取信息增强器类 + + Args: + modality: 模态类型 + + Returns: + 信息增强器类,如果没有找到则返回 None + """ + return ENHANCER_REGISTRY.get(modality.value.lower()) + + +def list_registered_enhancers() -> dict[str, str]: + """ + 列出所有已注册的信息增强器 + + Returns: + 模态类型到信息增强器类名的映射字典 + """ + return {modality: cls.__name__ for modality, cls in ENHANCER_REGISTRY.items()} diff --git a/enhancers/information_enhancer.py b/enhancers/information_enhancer.py deleted file mode 100644 index aeb8758..0000000 --- a/enhancers/information_enhancer.py +++ /dev/null @@ -1,62 +0,0 @@ -from abc import ABC, abstractmethod - -from parsers.base_models import ChunkData, ChunkType - - -class InformationEnhancer(ABC): - """信息增强器基类""" - @abstractmethod - async def enhance(self, information: ChunkData) -> ChunkData: - """增强信息""" - pass - -class TableInformationEnhancer(InformationEnhancer): - """表格信息增强器""" - - async def enhance(self, information: ChunkData) -> ChunkData: - """增强信息""" - return information - -class FormulasInformationEnhancer(InformationEnhancer): - """公式信息增强器""" - - async def enhance(self, information: ChunkData) -> ChunkData: - """增强信息""" - return information - -class ImageInformationEnhancer(InformationEnhancer): - """图片信息增强器""" - - async def enhance(self, information: ChunkData) -> ChunkData: - """增强信息""" - return information - -class InformationEnhancerFactory: - """信息增强器工厂""" - - def __init__(self) -> None: - self.enhancers = [ - TableInformationEnhancer(), - FormulasInformationEnhancer(), - ImageInformationEnhancer() - ] - - def get_enhancer(self, information: ChunkData) -> InformationEnhancer|None: - """获取信息增强器""" - match information.type: - case ChunkType.TABLE: - return TableInformationEnhancer() - case ChunkType.FORMULA: - return FormulasInformationEnhancer() - case ChunkType.IMAGE: - return ImageInformationEnhancer() - case _: - return None - - async def enhance_information(self, information: ChunkData) -> ChunkData: - """增强信息""" - enhancer = self.get_enhancer(information) - if not enhancer: - raise ValueError(f"不支持的模态类型: {information.type}") - return await enhancer.enhance(information) - diff --git a/parsers/__init__.py b/parsers/__init__.py index 9a7e20d..8f756b5 100644 --- a/parsers/__init__.py +++ b/parsers/__init__.py @@ -1,6 +1,6 @@ # Parsers package -from .base_models import DocumentData, DocumentParser +from .base_models import ChunkData, ChunkType, DocumentData, DocumentParser from .parser_registry import ( PARSER_REGISTRY, get_parser, @@ -12,6 +12,8 @@ __all__ = [ 'DocumentData', 'DocumentParser', + 'ChunkData', + 'ChunkType', 'PARSER_REGISTRY', 'register_parser', 'get_parser', diff --git a/worker.py b/worker.py index 44e9d76..a3e2366 100644 --- a/worker.py +++ b/worker.py @@ -3,15 +3,13 @@ from sanic import Sanic -from enhancers.information_enhancer import InformationEnhancerFactory -from parsers import get_parser, load_all_parsers -from parsers.base_models import ChunkData +from enhancers import get_enhancer +from parsers import ChunkData, ChunkType, get_parser, load_all_parsers async def worker(app: Sanic) -> dict[str, Any]: # 使用工厂获取合适的解析器 load_all_parsers() - enhancer_factory = InformationEnhancerFactory() redis = app.ctx.redis while True: task = await redis.get_task() @@ -25,21 +23,29 @@ async def worker(app: Sanic) -> dict[str, Any]: parse_result = await parser.parse(file_path) if not parse_result.success: continue - chunk_list = parse_result.texts + parse_result.tables + parse_result.images + parse_result.formulas # 控制并发数量,防止访问量过大导致失败 - SEMAPHORE_LIMIT = 10 # 可根据实际情况调整 + SEMAPHORE_LIMIT = 10 semaphore = asyncio.Semaphore(SEMAPHORE_LIMIT) async def enhance_with_semaphore(chunk: ChunkData, semaphore: asyncio.Semaphore) -> ChunkData: async with semaphore: - return await enhancer_factory.enhance_information(chunk) - - # 并发增强每个信息 - enhanced_chunk_list = await asyncio.gather( - *(enhance_with_semaphore(chunk, semaphore) for chunk in chunk_list) - ) - parse_result.texts = enhanced_chunk_list[:len(parse_result.texts)] - parse_result.tables = enhanced_chunk_list[len(parse_result.texts):len(parse_result.texts) + len(parse_result.tables)] - parse_result.images = enhanced_chunk_list[len(parse_result.texts) + len(parse_result.tables):len(parse_result.texts) + len(parse_result.tables) + len(parse_result.images)] - parse_result.formulas = enhanced_chunk_list[len(parse_result.texts) + len(parse_result.tables) + len(parse_result.images):] + enhancer = get_enhancer(ChunkType(chunk.type)) + if not enhancer: + return chunk + return await enhancer.enhance(chunk) + + text_tasks = [enhance_with_semaphore(chunk, semaphore) for chunk in parse_result.texts] + table_tasks = [enhance_with_semaphore(chunk, semaphore) for chunk in parse_result.tables] + image_tasks = [enhance_with_semaphore(chunk, semaphore) for chunk in parse_result.images] + formula_tasks = [enhance_with_semaphore(chunk, semaphore) for chunk in parse_result.formulas] + + text_chunk_list = await asyncio.gather(*text_tasks) + table_chunk_list = await asyncio.gather(*table_tasks) + image_chunk_list = await asyncio.gather(*image_tasks) + formula_chunk_list = await asyncio.gather(*formula_tasks) + + parse_result.texts = text_chunk_list + parse_result.tables = table_chunk_list + parse_result.images = image_chunk_list + parse_result.formulas = formula_chunk_list return parse_result.model_dump(mode="json") From f4eb8c13fcb6b4b845067dc2d161eec05fade5d0 Mon Sep 17 00:00:00 2001 From: liningping <728359849@qq.com> Date: Thu, 21 Aug 2025 07:52:46 +0000 Subject: [PATCH 2/3] fix: ban cover enhancer --- enhancers/enhancer_registry.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/enhancers/enhancer_registry.py b/enhancers/enhancer_registry.py index e637c1c..6b60504 100644 --- a/enhancers/enhancer_registry.py +++ b/enhancers/enhancer_registry.py @@ -16,7 +16,7 @@ ENHANCER_REGISTRY: dict[str, type[InformationEnhancer]] = {} -def register_enhancer(modalities: list[str]) -> Callable[[type[InformationEnhancer]], type[InformationEnhancer]]: +def register_enhancer(modalities: list[ChunkType]) -> Callable[[type[InformationEnhancer]], type[InformationEnhancer]]: """ 信息增强器注册装饰器 @@ -38,9 +38,10 @@ def decorator(cls: type[InformationEnhancer]) -> type[InformationEnhancer]: # 注册到全局注册表 for modality in modalities: - modality = modality.lower() # 统一转换为小写 + modality = modality.value.lower() # 统一转换为小写 if modality in ENHANCER_REGISTRY: - logger.warning(f"覆盖已存在的信息增强器: {modality} -> {cls.__name__}") + logger.error(f"覆盖已存在的信息增强器: {modality} -> {cls.__name__}") + raise ValueError(f"尝试覆盖已存在的信息增强器: {modality} -> {cls.__name__}") ENHANCER_REGISTRY[modality] = cls logger.info(f"注册信息增强器: {modality} -> {cls.__name__}") @@ -48,7 +49,6 @@ def decorator(cls: type[InformationEnhancer]) -> type[InformationEnhancer]: return decorator - def get_enhancer(modality: ChunkType) -> InformationEnhancer | None: """ 根据模态类型获取合适的信息增强器实例 From 1f13645569b27f0d3789d75f7e6eee1e4070a5e1 Mon Sep 17 00:00:00 2001 From: liningping <728359849@qq.com> Date: Thu, 21 Aug 2025 07:58:44 +0000 Subject: [PATCH 3/3] fix:comfort mypy --- enhancers/enhancer_registry.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/enhancers/enhancer_registry.py b/enhancers/enhancer_registry.py index 6b60504..58a8d0a 100644 --- a/enhancers/enhancer_registry.py +++ b/enhancers/enhancer_registry.py @@ -38,12 +38,12 @@ def decorator(cls: type[InformationEnhancer]) -> type[InformationEnhancer]: # 注册到全局注册表 for modality in modalities: - modality = modality.value.lower() # 统一转换为小写 - if modality in ENHANCER_REGISTRY: - logger.error(f"覆盖已存在的信息增强器: {modality} -> {cls.__name__}") - raise ValueError(f"尝试覆盖已存在的信息增强器: {modality} -> {cls.__name__}") - ENHANCER_REGISTRY[modality] = cls - logger.info(f"注册信息增强器: {modality} -> {cls.__name__}") + modality_type = modality.value.lower() # 统一转换为小写 + if modality_type in ENHANCER_REGISTRY: + logger.error(f"覆盖已存在的信息增强器: {modality_type} -> {cls.__name__}") + raise ValueError(f"尝试覆盖已存在的信息增强器: {modality_type} -> {cls.__name__}") + ENHANCER_REGISTRY[modality_type] = cls + logger.info(f"注册信息增强器: {modality_type} -> {cls.__name__}") return cls