-
-
Notifications
You must be signed in to change notification settings - Fork 1.6k
/
Copy pathbase_node.py
237 lines (190 loc) · 8 KB
/
base_node.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
"""
This module defines the base node class for the ScrapeGraphAI application.
"""
import re
from abc import ABC, abstractmethod
from typing import List, Optional
from ..utils import get_logger
class BaseNode(ABC):
    """
    An abstract base class for nodes in a graph-based workflow,
    designed to perform specific actions when executed.

    Attributes:
        node_name (str): The unique identifier name for the node.
        node_type (str): Type of the node; 'node' or 'conditional_node'.
        input (str): Boolean expression defining the input keys needed from the state.
        output (List[str]): List of output keys to be updated in the state.
        min_input_len (int): Minimum required number of input keys.
        node_config (Optional[dict]): Additional configuration for the node.
        logger: The logger object returned by ``get_logger()``.

    Args:
        node_name (str): Name for identifying the node.
        node_type (str): Type of the node; must be 'node' or 'conditional_node'.
        input (str): Expression defining the input keys needed from the state.
        output (List[str]): List of output keys to be updated in the state.
        min_input_len (int, optional): Minimum required number of input keys; defaults to 1.
        node_config (Optional[dict], optional): Additional configuration
            for the node; defaults to None.

    Raises:
        ValueError: If `node_type` is not one of the allowed types.

    Example:
        >>> class MyNode(BaseNode):
        ...     def execute(self, state):
        ...         # Implementation of node logic here
        ...         return state
        ...
        >>> my_node = MyNode("ExampleNode", "node", "input_spec", ["output_spec"])
        >>> my_node.execute({'key': 'value'})
        {'key': 'value'}
    """

    def __init__(
        self,
        node_name: str,
        node_type: str,
        input: str,
        output: List[str],
        min_input_len: int = 1,
        node_config: Optional[dict] = None,
    ):
        # Validate first so an invalid node never ends up half-initialised.
        if node_type not in ("node", "conditional_node"):
            raise ValueError(
                f"node_type must be 'node' or 'conditional_node', got '{node_type}'"
            )

        self.node_name = node_name
        self.node_type = node_type
        self.input = input
        self.output = output
        self.min_input_len = min_input_len
        self.node_config = node_config
        self.logger = get_logger()

    @abstractmethod
    def execute(self, state: dict) -> dict:
        """
        Execute the node's logic based on the current state and update it accordingly.

        Args:
            state (dict): The current state of the graph.

        Returns:
            dict: The updated state after executing the node's logic.
        """

    def update_config(self, params: dict, overwrite: bool = False) -> None:
        """
        Set instance attributes from the given parameters.

        Each key of `params` is assigned as an attribute on the node.
        Attributes that already exist on the node (whatever their current
        value) are left untouched unless `overwrite` is True. Note that only
        attributes are set; the `node_config` dictionary itself is not
        modified by this method.

        Args:
            params (dict): Mapping of attribute names to the values to set.
            overwrite (bool): When True, attributes that already exist are
                replaced; when False (default) they are kept as they are.
        """
        for key, val in params.items():
            if hasattr(self, key) and not overwrite:
                continue
            setattr(self, key, val)

    def get_input_keys(self, state: dict) -> List[str]:
        """
        Determines the necessary state keys based on the input specification.

        Args:
            state (dict): The current state of the graph used to parse input keys.

        Returns:
            List[str]: A list of input keys required for node operation.

        Raises:
            ValueError: If the input expression cannot be parsed against the
                state, or if fewer than `min_input_len` keys are resolved.
                The underlying error is chained as the cause.
        """
        try:
            input_keys = self._parse_input_keys(state, self.input)
            self._validate_input_keys(input_keys)
            return input_keys
        except ValueError as e:
            raise ValueError(f"Error parsing input keys for {self.node_name}") from e

    def _validate_input_keys(self, input_keys: List[str]) -> None:
        """
        Validates if the provided input keys meet the minimum length requirement.

        Args:
            input_keys (List[str]): The list of input keys to validate.

        Raises:
            ValueError: If the number of input keys is less than the minimum required.
        """
        if len(input_keys) < self.min_input_len:
            raise ValueError(
                f"{self.node_name} requires at least {self.min_input_len} input keys,\n"
                f"got {len(input_keys)}."
            )

    def _parse_input_keys(self, state: dict, expression: str) -> List[str]:
        """
        Parses the input keys expression to extract
        relevant keys from the state based on logical conditions.

        The expression can contain AND (&), OR (|), and parentheses to group
        conditions. `&` binds tighter than `|`. An OR chain resolves to the
        keys of its first alternative whose operands are all satisfied; an AND
        chain resolves to the keys of all its operands. A parenthesised group
        keeps its own semantics when combined with other operands, so
        ``(a & b) & c`` yields ``a``, ``b`` and ``c``.

        Args:
            state (dict): The current state of the graph.
            expression (str): The input keys expression to parse.

        Returns:
            List[str]: The matching key names, in evaluation order and
            without duplicates.

        Raises:
            ValueError: If the expression is empty or blank, has adjacent keys
                without an operator, has misplaced operators or unbalanced
                parentheses, or if no state keys match the expression.
        """
        # A blank expression would otherwise collapse to "" below and crash
        # with IndexError on expression[0].
        if not expression or not expression.strip():
            raise ValueError("Empty expression.")

        # With an empty state the alternation would be empty and the pattern
        # would match at any word boundary, so only run the check when there
        # are keys to look for.
        if state:
            alternatives = "|".join(re.escape(key) for key in state)
            adjacent_keys = rf"\b({alternatives})(\b\s*\b)({alternatives})\b"
            if re.search(adjacent_keys, expression):
                raise ValueError(
                    "Adjacent state keys found without an operator between them."
                )

        expression = expression.replace(" ", "")

        if (
            expression[0] in "&|"
            or expression[-1] in "&|"
            or any(pair in expression for pair in ("&&", "||", "&|", "|&"))
        ):
            raise ValueError("Invalid operator usage.")

        # Track nesting depth rather than raw counts so that a ")" appearing
        # before its "(" is rejected instead of reaching the evaluator.
        depth = 0
        for char in expression:
            if char == "(":
                depth += 1
            elif char == ")":
                depth -= 1
                if depth < 0:
                    raise ValueError(
                        "Missing or unbalanced parentheses in expression."
                    )
        if depth != 0:
            raise ValueError("Missing or unbalanced parentheses in expression.")

        # Resolved parenthesised groups, keyed by the placeholder token that
        # stands in for them in the remaining expression.
        groups: dict = {}

        def resolve_operand(operand: str) -> Optional[List[str]]:
            """Return the keys an operand contributes, or None if unsatisfied."""
            operand = operand.strip()
            if operand in groups:
                return groups[operand] or None
            return [operand] if operand in state else None

        def evaluate_flat(exp: str) -> List[str]:
            """Evaluate an expression that contains no parentheses."""
            for or_segment in exp.split("|"):
                matched: List[str] = []
                for operand in or_segment.split("&"):
                    keys = resolve_operand(operand)
                    if keys is None:
                        break
                    matched.extend(keys)
                else:
                    # Every operand of this alternative was satisfied.
                    return matched
            return []

        # Resolve innermost groups first, replacing each with a placeholder
        # so the group is later treated as a single operand.
        remaining = expression
        while "(" in remaining:
            start = remaining.rfind("(")
            end = remaining.find(")", start)
            token = f"\x00{len(groups)}\x00"
            groups[token] = evaluate_flat(remaining[start + 1 : end])
            remaining = remaining[:start] + token + remaining[end + 1 :]

        result = evaluate_flat(remaining)

        if not result:
            raise ValueError(
                "No state keys matched the expression.\n"
                f"Expression was {expression}.\n"
                f"State contains keys: {', '.join(state.keys())}"
            )

        # Drop duplicates while keeping first-seen order.
        return list(dict.fromkeys(result))