Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/paperqa/contrib/zotero.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
"""This module gets PDF files from the user's Zotero library."""

import asyncio
import logging
import os
from collections.abc import Awaitable
from pathlib import Path
from typing import cast

Expand Down Expand Up @@ -256,6 +258,8 @@ def iterate( # noqa: PLR0912
title = item["data"].get("title", "")
if len(items) >= start:
parsed_text = self._parse_pdf(pdf)
if isinstance(parsed_text, Awaitable):
parsed_text = asyncio.run(parsed_text)
if not isinstance(parsed_text.content, dict):
raise ValueError(
"The content type coming from the PDF parser"
Expand Down
30 changes: 26 additions & 4 deletions src/paperqa/readers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@
from importlib.metadata import version
from math import ceil
from pathlib import Path
from typing import Literal, Protocol, cast, overload, runtime_checkable
from typing import Literal, Protocol, TypeAlias, cast, overload, runtime_checkable

import anyio
import tiktoken
from aviary.core import is_coroutine_callable
from html2text import __version__ as html2text_version
from html2text import html2text

Expand All @@ -26,8 +27,8 @@


@runtime_checkable
class PDFParserFn(Protocol):
"""Protocol for parsing a PDF."""
class SyncPDFParserFn(Protocol):
"""Protocol for synchronously parsing a PDF."""

def __call__(
self,
Expand All @@ -38,6 +39,22 @@ def __call__(
) -> ParsedText: ...


@runtime_checkable
class AsyncPDFParserFn(Protocol):
    """Structural type for a coroutine function that parses a PDF.

    Any ``async`` callable whose signature matches ``__call__`` below satisfies
    this protocol; its awaited result is a ``ParsedText``.

    Note:
        ``isinstance`` checks against a runtime-checkable protocol only verify
        that ``__call__`` exists, not whether it is a coroutine function, so
        they cannot distinguish asynchronous parsers from synchronous ones.
    """

    async def __call__(
        self,
        path: str | os.PathLike,
        page_size_limit: int | None = None,
        page_range: int | tuple[int, int] | None = None,
        **kwargs,
    ) -> ParsedText:
        """Parse the PDF located at ``path``.

        Args:
            path: Filesystem location of the PDF.
            page_size_limit: Optional per-page size cap; the unit is defined by
                the concrete parser (TODO confirm).
            page_range: Optional single page or two-integer pair selecting
                pages; presumably (start, end), with indexing and inclusivity
                defined by the concrete parser (TODO confirm).
            **kwargs: Additional parser-specific options.

        Returns:
            The parsed document contents.
        """
        ...


# A PDF parser may be either synchronous or asynchronous. Callers holding a
# `PDFParserFn` must determine which flavor they have (e.g. by checking whether
# the callable is a coroutine function) before deciding to await its result.
PDFParserFn: TypeAlias = SyncPDFParserFn | AsyncPDFParserFn


async def parse_image(
path: str | os.PathLike, validator: Callable[[bytes], Awaitable] | None = None, **_
) -> ParsedText:
Expand Down Expand Up @@ -429,7 +446,12 @@ async def read_doc( # noqa: PLR0912
raise ValueError("When parsing a PDF, a parsing function must be provided.")
# Some PDF parsers are not thread-safe,
# so can't use multithreading via `asyncio.to_thread` here
parsed_text: ParsedText = parse_pdf(path, **parser_kwargs)
if is_coroutine_callable(parse_pdf):
parsed_text: ParsedText = await cast(AsyncPDFParserFn, parse_pdf)(
path, **parser_kwargs
)
else:
parsed_text = cast(SyncPDFParserFn, parse_pdf)(path, **parser_kwargs)
elif str_path.endswith(".txt"):
# TODO: Make parse_text async
parsed_text = await asyncio.to_thread(parse_text, path, **parser_kwargs)
Expand Down