5 files changed +52
-2
lines changed Original file line number Diff line number Diff line change 10
10
from graph_retriever .content import Content
11
11
from graph_retriever .types import Node
12
12
13
+ DEFAULT_SELECT_K = 5
14
+
13
15
14
16
class NodeTracker :
15
17
"""Helper class for tracking traversal progress."""
@@ -95,16 +97,27 @@ class Strategy(abc.ABC):
95
97
If `None`, there is no limit.
96
98
max_depth :
97
99
Maximum traversal depth. If `None`, there is no limit.
100
+ k:
101
+ Deprecated: Use `select_k` instead.
102
+ Maximum number of nodes to select and return during traversal.
98
103
"""
99
104
100
- select_k : int = 5
105
+ select_k : int = dataclasses . field ( default = DEFAULT_SELECT_K )
101
106
start_k : int = 4
102
107
adjacent_k : int = 10
103
108
max_traverse : int | None = None
104
109
max_depth : int | None = None
110
+ k : int = dataclasses .field (default = DEFAULT_SELECT_K , repr = False )
105
111
106
112
_query_embedding : list [float ] = dataclasses .field (default_factory = list )
107
113
114
def __post_init__(self):
    """
    Reconcile the deprecated `k` field with `select_k`.

    `DEFAULT_SELECT_K` doubles as the "not provided" marker for both
    fields: `k` is only honoured when it differs from the default while
    `select_k` still equals it. In every other case `select_k` wins.
    Either way, both fields hold the same value afterwards.
    """
    select_k_is_default = self.select_k == DEFAULT_SELECT_K
    k_differs = self.k != DEFAULT_SELECT_K
    if select_k_is_default and k_differs:
        # Only the deprecated field carries a non-default value: adopt it.
        self.select_k = self.k
        return
    # Otherwise mirror `select_k` into `k` so the two never disagree.
    self.k = self.select_k
120
+
108
121
@abc .abstractmethod
109
122
def iteration (self , * , nodes : Iterable [Node ], tracker : NodeTracker ) -> None :
110
123
"""
Original file line number Diff line number Diff line change @@ -30,6 +30,9 @@ class Eager(Strategy):
30
30
Number of documents to fetch for each outgoing edge.
31
31
max_depth :
32
32
Maximum traversal depth. If `None`, there is no limit.
33
+ k:
34
+ Deprecated: Use `select_k` instead.
35
+ Maximum number of nodes to select and return during traversal.
33
36
"""
34
37
35
38
@override
Original file line number Diff line number Diff line change @@ -67,6 +67,9 @@ class Mmr(Strategy):
67
67
min_mmr_score :
68
68
Only nodes with a score greater than or equal to this value will be
69
69
selected.
70
+ k:
71
+ Deprecated: Use `select_k` instead.
72
+ Maximum number of nodes to select and return during traversal.
70
73
"""
71
74
72
75
lambda_mult : float = 0.5
Original file line number Diff line number Diff line change @@ -19,7 +19,33 @@ def __lt__(self, other: "_ScoredNode") -> bool:
19
19
20
20
@dataclasses .dataclass
21
21
class Scored (Strategy ):
22
- """Strategy selecting nodes using a scoring function."""
22
+ """
23
+ Scored traversal strategy.
24
+
25
+ This strategy uses a scoring function to select nodes using a local maximum
26
+ approach. In each iteration, it chooses the top scoring nodes available and
27
+ then traverses the connected nodes.
28
+
29
+ Parameters
30
+ ----------
31
+ scorer:
32
+ A callable function that returns the score of a node.
33
+ select_k :
34
+ Maximum number of nodes to retrieve during traversal.
35
+ start_k :
36
+ Number of documents to fetch via similarity for starting the traversal.
37
+ Added to any initial roots provided to the traversal.
38
+ adjacent_k :
39
+ Number of documents to fetch for each outgoing edge.
40
+ max_depth :
41
+ Maximum traversal depth. If `None`, there is no limit.
42
+ per_iteration_limit:
43
+ Maximum number of nodes to select and traverse during a single
44
+ iteration.
45
+ k:
46
+ Deprecated: Use `select_k` instead.
47
+ Maximum number of nodes to select and return during traversal.
48
+ """
23
49
24
50
scorer : Callable [[Node ], float ]
25
51
_nodes : list [_ScoredNode ] = dataclasses .field (default_factory = list )
Original file line number Diff line number Diff line change @@ -102,3 +102,8 @@ def test_build_strategy_base_override_mmr():
102
102
Strategy .build (
103
103
base_strategy = override_strategy , strategy = base_strategy , lambda_mult = 0.2
104
104
)
105
+
106
+
107
def test_setting_k_sets_select_k():
    """Building a strategy with the deprecated `k` equals using `select_k`."""
    for strategy_cls, limit in ((Eager, 4), (Mmr, 3)):
        via_select_k = strategy_cls(select_k=limit)
        via_k = strategy_cls(k=limit)
        assert via_select_k == via_k
0 commit comments