Skip to content

Commit a4db19a

Browse files
authored
Consolidated LDP imports into ldp_shims module (#772)
1 parent d22eda1 commit a4db19a

File tree

5 files changed

+68
-49
lines changed

5 files changed

+68
-49
lines changed

paperqa/_ldp_shims.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
"""Centralized place for lazy LDP imports."""
2+
3+
__all__ = [
4+
"HAS_LDP_INSTALLED",
5+
"Agent",
6+
"Callback",
7+
"ComputeTrajectoryMetricsMixin",
8+
"HTTPAgentClient",
9+
"Memory",
10+
"MemoryAgent",
11+
"ReActAgent",
12+
"RolloutManager",
13+
"SimpleAgent",
14+
"SimpleAgentState",
15+
"UIndexMemoryModel",
16+
"_Memories",
17+
"discounted_returns",
18+
"set_training_mode",
19+
]
20+
21+
from pydantic import TypeAdapter
22+
23+
try:
24+
from ldp.agent import (
25+
Agent,
26+
HTTPAgentClient,
27+
MemoryAgent,
28+
ReActAgent,
29+
SimpleAgent,
30+
SimpleAgentState,
31+
)
32+
from ldp.alg import Callback, ComputeTrajectoryMetricsMixin, RolloutManager
33+
from ldp.graph.memory import Memory, UIndexMemoryModel
34+
from ldp.graph.op_utils import set_training_mode
35+
from ldp.utils import discounted_returns
36+
37+
_Memories = TypeAdapter(dict[int, Memory] | list[Memory]) # type: ignore[var-annotated]
38+
39+
HAS_LDP_INSTALLED = True
40+
except ImportError:
41+
HAS_LDP_INSTALLED = False
42+
43+
class ComputeTrajectoryMetricsMixin: # type: ignore[no-redef]
44+
"""Placeholder parent class for when ldp isn't installed."""
45+
46+
class Callback: # type: ignore[no-redef]
47+
"""Placeholder parent class for when ldp isn't installed."""
48+
49+
RolloutManager = None # type: ignore[assignment,misc]
50+
discounted_returns = None # type: ignore[assignment]

paperqa/agents/main.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,7 @@
2020
stop_after_attempt,
2121
)
2222

23-
try:
24-
from ldp.alg import Callback, RolloutManager
25-
except ImportError:
26-
27-
class Callback: # type: ignore[no-redef]
28-
"""Placeholder parent class for when ldp isn't installed."""
29-
30-
RolloutManager = None # type: ignore[assignment,misc]
31-
23+
from paperqa._ldp_shims import Callback, RolloutManager
3224
from paperqa.docs import Docs
3325
from paperqa.settings import AgentSettings
3426
from paperqa.types import PQASession

paperqa/agents/task.py

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -24,21 +24,9 @@
2424
ToolResponseMessage,
2525
)
2626
from aviary.env import ENV_REGISTRY
27-
28-
from paperqa.types import DocDetails
29-
30-
from .search import SearchIndex, maybe_get_manifest
31-
32-
try:
33-
from ldp.alg import ComputeTrajectoryMetricsMixin
34-
except ImportError:
35-
36-
class ComputeTrajectoryMetricsMixin: # type: ignore[no-redef]
37-
"""Placeholder for when ldp isn't installed."""
38-
39-
4027
from llmclient import EmbeddingModel, LiteLLMModel, LLMModel
4128

29+
from paperqa._ldp_shims import ComputeTrajectoryMetricsMixin
4230
from paperqa.docs import Docs
4331
from paperqa.litqa import (
4432
DEFAULT_EVAL_MODEL_NAME,
@@ -47,10 +35,11 @@ class ComputeTrajectoryMetricsMixin: # type: ignore[no-redef]
4735
LitQAEvaluation,
4836
read_litqa_v2_from_hub,
4937
)
50-
from paperqa.types import PQASession
38+
from paperqa.types import DocDetails, PQASession
5139

5240
from .env import POPULATE_FROM_SETTINGS, PaperQAEnvironment
5341
from .models import QueryRequest
42+
from .search import SearchIndex, maybe_get_manifest
5443
from .tools import Complete
5544

5645
if TYPE_CHECKING:

paperqa/litqa.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,10 @@
1010
from enum import StrEnum
1111
from typing import TYPE_CHECKING, Literal, Self
1212

13-
try:
14-
from ldp.utils import discounted_returns
15-
except ImportError:
16-
discounted_returns = None # type: ignore[assignment]
17-
1813
from aviary.core import Message
1914
from llmclient import LiteLLMModel, LLMModel
2015

16+
from paperqa._ldp_shims import discounted_returns
2117
from paperqa.prompts import EVAL_PROMPT_TEMPLATE, QA_PROMPT_TEMPLATE
2218
from paperqa.settings import make_default_litellm_model_list_settings
2319
from paperqa.types import PQASession

paperqa/settings.py

Lines changed: 13 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -10,37 +10,29 @@
1010

1111
import anyio
1212
from aviary.core import ToolSelector
13+
from llmclient import EmbeddingModel, LiteLLMModel, embedding_model_factory
1314
from pydantic import (
1415
BaseModel,
1516
ConfigDict,
1617
Field,
17-
TypeAdapter,
1818
computed_field,
1919
field_validator,
2020
model_validator,
2121
)
2222
from pydantic_settings import BaseSettings, CliSettingsSource, SettingsConfigDict
2323

24-
try:
25-
from ldp.agent import (
26-
Agent,
27-
HTTPAgentClient,
28-
MemoryAgent,
29-
ReActAgent,
30-
SimpleAgent,
31-
SimpleAgentState,
32-
)
33-
from ldp.graph.memory import Memory, UIndexMemoryModel
34-
from ldp.graph.op_utils import set_training_mode
35-
36-
_Memories = TypeAdapter(dict[int, Memory] | list[Memory]) # type: ignore[var-annotated]
37-
38-
HAS_LDP_INSTALLED = True
39-
except ImportError:
40-
HAS_LDP_INSTALLED = False
41-
42-
from llmclient import EmbeddingModel, LiteLLMModel, embedding_model_factory
43-
24+
from paperqa._ldp_shims import (
25+
HAS_LDP_INSTALLED,
26+
Agent,
27+
HTTPAgentClient,
28+
MemoryAgent,
29+
ReActAgent,
30+
SimpleAgent,
31+
SimpleAgentState,
32+
UIndexMemoryModel,
33+
_Memories,
34+
set_training_mode,
35+
)
4436
from paperqa.prompts import (
4537
CONTEXT_INNER_PROMPT,
4638
CONTEXT_OUTER_PROMPT,

0 commit comments

Comments
 (0)