diff --git a/paperqa/agents/__init__.py b/paperqa/agents/__init__.py index d542c0162..9b45f0794 100644 --- a/paperqa/agents/__init__.py +++ b/paperqa/agents/__init__.py @@ -6,6 +6,7 @@ from pathlib import Path from typing import Any +from aviary.utils import MultipleChoiceQuestion from pydantic_settings import CliSettingsSource from rich.logging import RichHandler @@ -97,7 +98,7 @@ def configure_cli_logging(verbosity: int | Settings = 0) -> None: print(f"PaperQA version: {__version__}") -def ask(query: str, settings: Settings) -> AnswerResponse: +def ask(query: str | MultipleChoiceQuestion, settings: Settings) -> AnswerResponse: """Query PaperQA via an agent.""" configure_cli_logging(settings) return get_loop().run_until_complete( @@ -109,7 +110,7 @@ def ask(query: str, settings: Settings) -> AnswerResponse: def search_query( - query: str, + query: str | MultipleChoiceQuestion, index_name: str, settings: Settings, ) -> list[tuple[AnswerResponse, str] | tuple[Any, str]]: @@ -119,7 +120,7 @@ def search_query( index_name = settings.get_index_name() return get_loop().run_until_complete( index_search( - query, + query if isinstance(query, str) else query.question_prompt, index_name=index_name, index_directory=settings.agent.index.index_directory, ) diff --git a/paperqa/agents/main.py b/paperqa/agents/main.py index 5f4532813..460bcfc3e 100644 --- a/paperqa/agents/main.py +++ b/paperqa/agents/main.py @@ -11,6 +11,7 @@ ToolSelector, ToolSelectorLedger, ) +from aviary.utils import MultipleChoiceQuestion from pydantic import BaseModel from rich.console import Console from tenacity import ( @@ -49,12 +50,12 @@ async def agent_query( - query: str | QueryRequest, + query: str | MultipleChoiceQuestion | QueryRequest, docs: Docs | None = None, agent_type: str | type = DEFAULT_AGENT_TYPE, **runner_kwargs, ) -> AnswerResponse: - if isinstance(query, str): + if isinstance(query, str | MultipleChoiceQuestion): query = QueryRequest(query=query) if docs is None: docs = Docs()