Skip to content

Commit af34d82

Browse files
authoredMar 6, 2025
restored 'k' parameter when creating strategies (#159)
* restored 'k' parameter when creating strategies * lint
1 parent 1a92a14 commit af34d82

File tree

5 files changed

+52
-2
lines changed

5 files changed

+52
-2
lines changed
 

‎packages/graph-retriever/src/graph_retriever/strategies/base.py

+14-1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
from graph_retriever.content import Content
1111
from graph_retriever.types import Node
1212

13+
DEFAULT_SELECT_K = 5
14+
1315

1416
class NodeTracker:
1517
"""Helper class for tracking traversal progress."""
@@ -95,16 +97,27 @@ class Strategy(abc.ABC):
9597
If `None`, there is no limit.
9698
max_depth :
9799
Maximum traversal depth. If `None`, there is no limit.
100+
k:
101+
Deprecated: Use `select_k` instead.
102+
Maximum number of nodes to select and return during traversal.
98103
"""
99104

100-
select_k: int = 5
105+
select_k: int = dataclasses.field(default=DEFAULT_SELECT_K)
101106
start_k: int = 4
102107
adjacent_k: int = 10
103108
max_traverse: int | None = None
104109
max_depth: int | None = None
110+
k: int = dataclasses.field(default=DEFAULT_SELECT_K, repr=False)
105111

106112
_query_embedding: list[float] = dataclasses.field(default_factory=list)
107113

114+
def __post_init__(self):
115+
"""Allow passing the deprecated 'k' value instead of 'select_k'."""
116+
if self.select_k == DEFAULT_SELECT_K and self.k != DEFAULT_SELECT_K:
117+
self.select_k = self.k
118+
else:
119+
self.k = self.select_k
120+
108121
@abc.abstractmethod
109122
def iteration(self, *, nodes: Iterable[Node], tracker: NodeTracker) -> None:
110123
"""

‎packages/graph-retriever/src/graph_retriever/strategies/eager.py

+3
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ class Eager(Strategy):
3030
Number of documents to fetch for each outgoing edge.
3131
max_depth :
3232
Maximum traversal depth. If `None`, there is no limit.
33+
k:
34+
Deprecated: Use `select_k` instead.
35+
Maximum number of nodes to select and return during traversal.
3336
"""
3437

3538
@override

‎packages/graph-retriever/src/graph_retriever/strategies/mmr.py

+3
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,9 @@ class Mmr(Strategy):
6767
min_mmr_score :
6868
Only nodes with a score greater than or equal to this value will be
6969
selected.
70+
k:
71+
Deprecated: Use `select_k` instead.
72+
Maximum number of nodes to select and return during traversal.
7073
"""
7174

7275
lambda_mult: float = 0.5

‎packages/graph-retriever/src/graph_retriever/strategies/scored.py

+27-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,33 @@ def __lt__(self, other: "_ScoredNode") -> bool:
1919

2020
@dataclasses.dataclass
2121
class Scored(Strategy):
22-
"""Strategy selecting nodes using a scoring function."""
22+
"""
23+
Scored traversal strategy.
24+
25+
This strategy uses a scoring function to select nodes using a local maximum
26+
approach. In each iteration, it chooses the top scoring nodes available and
27+
then traverses the connected nodes.
28+
29+
Parameters
30+
----------
31+
scorer:
32+
A callable function that returns the score of a node.
33+
select_k :
34+
Maximum number of nodes to retrieve during traversal.
35+
start_k :
36+
Number of documents to fetch via similarity for starting the traversal.
37+
Added to any initial roots provided to the traversal.
38+
adjacent_k :
39+
Number of documents to fetch for each outgoing edge.
40+
max_depth :
41+
Maximum traversal depth. If `None`, there is no limit.
42+
per_iteration_limit:
43+
Maximum number of nodes to select and traverse during a single
44+
iteration.
45+
k:
46+
Deprecated: Use `select_k` instead.
47+
Maximum number of nodes to select and return during traversal.
48+
"""
2349

2450
scorer: Callable[[Node], float]
2551
_nodes: list[_ScoredNode] = dataclasses.field(default_factory=list)

‎packages/graph-retriever/tests/strategies/test_base.py

+5
Original file line numberDiff line numberDiff line change
@@ -102,3 +102,8 @@ def test_build_strategy_base_override_mmr():
102102
Strategy.build(
103103
base_strategy=override_strategy, strategy=base_strategy, lambda_mult=0.2
104104
)
105+
106+
107+
def test_setting_k_sets_select_k():
108+
assert Eager(select_k=4) == Eager(k=4)
109+
assert Mmr(select_k=3) == Mmr(k=3)

0 commit comments

Comments
 (0)
Failed to load comments.