2
2
SmartScraperGraph Module
3
3
"""
4
4
from typing import Optional
5
- import logging
6
5
from pydantic import BaseModel
7
6
from .base_graph import BaseGraph
8
7
from .abstract_graph import AbstractGraph
9
8
from ..nodes import (
10
9
FetchNode ,
11
10
ParseNode ,
12
11
ReasoningNode ,
13
- GenerateAnswerNode
12
+ GenerateAnswerNode ,
13
+ ConditionalNode
14
14
)
15
+ from ..prompts import REGEN_ADDITIONAL_INFO
15
16
16
17
class SmartScraperGraph (AbstractGraph ):
17
18
"""
@@ -89,6 +90,28 @@ def _create_graph(self) -> BaseGraph:
89
90
}
90
91
)
91
92
93
+ cond_node = None
94
+ regen_node = None
95
+ if self .config .get ("reattempt" ) is True :
96
+ cond_node = ConditionalNode (
97
+ input = "results" ,
98
+ output = ["results" ],
99
+ node_name = "ConditionalNode" ,
100
+ node_config = {
101
+ "key_name" : "results" ,
102
+ "condition" : 'results and results!="NA"' ,
103
+ }
104
+ )
105
+ regen_node = GenerateAnswerNode (
106
+ input = "user_prompt & results" ,
107
+ output = ["answer" ],
108
+ node_config = {
109
+ "llm_model" : self .llm_model ,
110
+ "additional_info" : REGEN_ADDITIONAL_INFO ,
111
+ "schema" : self .schema ,
112
+ }
113
+ )
114
+
92
115
if self .config .get ("html_mode" ) is False :
93
116
parse_node = ParseNode (
94
117
input = "doc" ,
@@ -99,6 +122,7 @@ def _create_graph(self) -> BaseGraph:
99
122
}
100
123
)
101
124
125
+ reasoning_node = None
102
126
if self .config .get ("reasoning" ):
103
127
reasoning_node = ReasoningNode (
104
128
input = "user_prompt & (relevant_chunks | parsed_doc | doc)" ,
@@ -109,68 +133,72 @@ def _create_graph(self) -> BaseGraph:
109
133
"schema" : self .schema ,
110
134
}
111
135
)
136
+
137
+ # Define the graph variation configurations
138
+ # (html_mode, reasoning, reattempt)
139
+ graph_variation_config = {
140
+ (False , True , False ): {
141
+ "nodes" : [fetch_node , parse_node , reasoning_node , generate_answer_node ],
142
+ "edges" : [(fetch_node , parse_node ), (parse_node , reasoning_node ), (reasoning_node , generate_answer_node )]
143
+ },
144
+ (True , True , False ): {
145
+ "nodes" : [fetch_node , reasoning_node , generate_answer_node ],
146
+ "edges" : [(fetch_node , reasoning_node ), (reasoning_node , generate_answer_node )]
147
+ },
148
+ (True , False , False ): {
149
+ "nodes" : [fetch_node , generate_answer_node ],
150
+ "edges" : [(fetch_node , generate_answer_node )]
151
+ },
152
+ (False , False , False ): {
153
+ "nodes" : [fetch_node , parse_node , generate_answer_node ],
154
+ "edges" : [(fetch_node , parse_node ), (parse_node , generate_answer_node )]
155
+ },
156
+ (False , True , True ): {
157
+ "nodes" : [fetch_node , parse_node , reasoning_node , generate_answer_node , cond_node , regen_node ],
158
+ "edges" : [(fetch_node , parse_node ), (parse_node , reasoning_node ), (reasoning_node , generate_answer_node ),
159
+ (generate_answer_node , cond_node ), (cond_node , regen_node ), (cond_node , None )]
160
+ },
161
+ (True , True , True ): {
162
+ "nodes" : [fetch_node , reasoning_node , generate_answer_node , cond_node , regen_node ],
163
+ "edges" : [(fetch_node , reasoning_node ), (reasoning_node , generate_answer_node ),
164
+ (generate_answer_node , cond_node ), (cond_node , regen_node ), (cond_node , None )]
165
+ },
166
+ (True , False , True ): {
167
+ "nodes" : [fetch_node , generate_answer_node , cond_node , regen_node ],
168
+ "edges" : [(fetch_node , generate_answer_node ), (generate_answer_node , cond_node ),
169
+ (cond_node , regen_node ), (cond_node , None )]
170
+ },
171
+ (False , False , True ): {
172
+ "nodes" : [fetch_node , parse_node , generate_answer_node , cond_node , regen_node ],
173
+ "edges" : [(fetch_node , parse_node ), (parse_node , generate_answer_node ),
174
+ (generate_answer_node , cond_node ), (cond_node , regen_node ), (cond_node , None )]
175
+ }
176
+ }
112
177
113
- if self .config .get ("html_mode" ) is False and self .config .get ("reasoning" ) is True :
114
-
115
- return BaseGraph (
116
- nodes = [
117
- fetch_node ,
118
- parse_node ,
119
- reasoning_node ,
120
- generate_answer_node ,
121
- ],
122
- edges = [
123
- (fetch_node , parse_node ),
124
- (parse_node , reasoning_node ),
125
- (reasoning_node , generate_answer_node )
126
- ],
127
- entry_point = fetch_node ,
128
- graph_name = self .__class__ .__name__
129
- )
130
-
131
- elif self .config .get ("html_mode" ) is True and self .config .get ("reasoning" ) is True :
178
+ # Get the current conditions
179
+ html_mode = self .config .get ("html_mode" , False )
180
+ reasoning = self .config .get ("reasoning" , False )
181
+ reattempt = self .config .get ("reattempt" , False )
132
182
133
- return BaseGraph (
134
- nodes = [
135
- fetch_node ,
136
- reasoning_node ,
137
- generate_answer_node ,
138
- ],
139
- edges = [
140
- (fetch_node , reasoning_node ),
141
- (reasoning_node , generate_answer_node )
142
- ],
143
- entry_point = fetch_node ,
144
- graph_name = self .__class__ .__name__
145
- )
183
+ # Retrieve the appropriate graph configuration
184
+ config = graph_variation_config .get ((html_mode , reasoning , reattempt ))
146
185
147
- elif self . config . get ( "html_mode" ) is True and self . config . get ( "reasoning" ) is False :
186
+ if config :
148
187
return BaseGraph (
149
- nodes = [
150
- fetch_node ,
151
- generate_answer_node ,
152
- ],
153
- edges = [
154
- (fetch_node , generate_answer_node )
155
- ],
188
+ nodes = config ["nodes" ],
189
+ edges = config ["edges" ],
156
190
entry_point = fetch_node ,
157
191
graph_name = self .__class__ .__name__
158
192
)
159
193
194
+ # Default return if no conditions match
160
195
return BaseGraph (
161
- nodes = [
162
- fetch_node ,
163
- parse_node ,
164
- generate_answer_node ,
165
- ],
166
- edges = [
167
- (fetch_node , parse_node ),
168
- (parse_node , generate_answer_node )
169
- ],
170
- entry_point = fetch_node ,
171
- graph_name = self .__class__ .__name__
172
- )
173
-
196
+ nodes = [fetch_node , parse_node , generate_answer_node ],
197
+ edges = [(fetch_node , parse_node ), (parse_node , generate_answer_node )],
198
+ entry_point = fetch_node ,
199
+ graph_name = self .__class__ .__name__
200
+ )
201
+
174
202
def run (self ) -> str :
175
203
"""
176
204
Executes the scraping process and returns the answer to the prompt.
0 commit comments