diff --git a/src/paperqa/contrib/zotero.py b/src/paperqa/contrib/zotero.py index 237858e8a..9644d070f 100644 --- a/src/paperqa/contrib/zotero.py +++ b/src/paperqa/contrib/zotero.py @@ -1,7 +1,9 @@ """This module gets PDF files from the user's Zotero library.""" +import asyncio import logging import os +from collections.abc import Awaitable from pathlib import Path from typing import cast @@ -256,6 +258,8 @@ def iterate( # noqa: PLR0912 title = item["data"].get("title", "") if len(items) >= start: parsed_text = self._parse_pdf(pdf) + if isinstance(parsed_text, Awaitable): + parsed_text = asyncio.run(parsed_text) if not isinstance(parsed_text.content, dict): raise ValueError( "The content type coming from the PDF parser" diff --git a/src/paperqa/readers.py b/src/paperqa/readers.py index dd3e51c10..3c778a8ba 100644 --- a/src/paperqa/readers.py +++ b/src/paperqa/readers.py @@ -6,10 +6,11 @@ from importlib.metadata import version from math import ceil from pathlib import Path -from typing import Literal, Protocol, cast, overload, runtime_checkable +from typing import Literal, Protocol, TypeAlias, cast, overload, runtime_checkable import anyio import tiktoken +from aviary.core import is_coroutine_callable from html2text import __version__ as html2text_version from html2text import html2text @@ -26,8 +27,8 @@ @runtime_checkable -class PDFParserFn(Protocol): - """Protocol for parsing a PDF.""" +class SyncPDFParserFn(Protocol): + """Protocol for synchronously parsing a PDF.""" def __call__( self, @@ -38,6 +39,22 @@ def __call__( ) -> ParsedText: ... +@runtime_checkable +class AsyncPDFParserFn(Protocol): + """Protocol for asynchronously parsing a PDF.""" + + async def __call__( + self, + path: str | os.PathLike, + page_size_limit: int | None = None, + page_range: int | tuple[int, int] | None = None, + **kwargs, + ) -> ParsedText: ... + + +PDFParserFn: TypeAlias = SyncPDFParserFn | AsyncPDFParserFn + + async def parse_image( path: str | os.PathLike, validator: Callable[[bytes], Awaitable] | None = None, **_ ) -> ParsedText: @@ -429,7 +446,12 @@ async def read_doc( # noqa: PLR0912 raise ValueError("When parsing a PDF, a parsing function must be provided.") # Some PDF parsers are not thread-safe, # so can't use multithreading via `asyncio.to_thread` here - parsed_text: ParsedText = parse_pdf(path, **parser_kwargs) + if is_coroutine_callable(parse_pdf): + parsed_text: ParsedText = await cast(AsyncPDFParserFn, parse_pdf)( + path, **parser_kwargs + ) + else: + parsed_text = cast(SyncPDFParserFn, parse_pdf)(path, **parser_kwargs) elif str_path.endswith(".txt"): # TODO: Make parse_text async parsed_text = await asyncio.to_thread(parse_text, path, **parser_kwargs)