diff --git a/examples/openai/custom_graph_openai.py b/examples/openai/custom_graph_openai.py
index be5a4d55..175c51ab 100644
--- a/examples/openai/custom_graph_openai.py
+++ b/examples/openai/custom_graph_openai.py
@@ -62,19 +62,19 @@
 # ************************************************
 
 graph = BaseGraph(
-    nodes={
+    nodes=[
         robot_node,
         fetch_node,
         parse_node,
         rag_node,
         generate_answer_node,
-    },
-    edges={
+    ],
+    edges=[
         (robot_node, fetch_node),
         (fetch_node, parse_node),
         (parse_node, rag_node),
         (rag_node, generate_answer_node)
-    },
+    ],
     entry_point=robot_node
 )
 
diff --git a/manual deployment/commit_and_push_with_tests.sh b/manual deployment/commit_and_push_with_tests.sh
index 9cf7c1af..d97fe67f 100755
--- a/manual deployment/commit_and_push_with_tests.sh
+++ b/manual deployment/commit_and_push_with_tests.sh
@@ -13,6 +13,8 @@ pylint
 pylint scrapegraphai/**/*.py scrapegraphai/*.py tests/**/*.py
 
 cd tests
+poetry install
+
 # Run pytest
 if ! pytest; then
     echo "Pytest failed. Aborting commit and push."
diff --git a/scrapegraphai/graphs/base_graph.py b/scrapegraphai/graphs/base_graph.py
index 8df92b9a..c2ebfb0b 100644
--- a/scrapegraphai/graphs/base_graph.py
+++ b/scrapegraphai/graphs/base_graph.py
@@ -2,6 +2,7 @@
 Module for creating the base graphs
 """
 import time
+import warnings
 from langchain_community.callbacks import get_openai_callback
 
 
@@ -10,31 +11,37 @@ class BaseGraph:
     BaseGraph manages the execution flow of a graph composed of interconnected nodes.
 
     Attributes:
-        nodes (dict): A dictionary mapping each node's name to its corresponding node instance.
-        edges (dict): A dictionary representing the directed edges of the graph where each
-        key-value pair corresponds to the from-node and to-node relationship.
+        nodes (list): A list of node instances that make up the graph.
+        edges (list): A list of tuples representing the directed edges of the graph, where each
+        tuple defines a from-node to to-node relationship.
         entry_point (str): The name of the entry point node from which the graph execution begins.
 
     Methods:
-        execute(initial_state): Executes the graph's nodes starting from the entry point and 
+        execute(initial_state): Executes the graph's nodes starting from the entry point and
         traverses the graph based on the provided initial state.
 
     Args:
         nodes (iterable): An iterable of node instances that will be part of the graph.
-        edges (iterable): An iterable of tuples where each tuple represents a directed edge 
+        edges (iterable): An iterable of tuples where each tuple represents a directed edge
         in the graph, defined by a pair of nodes (from_node, to_node).
         entry_point (BaseNode): The node instance that represents the entry point of the graph.
     """
 
-    def __init__(self, nodes: dict, edges: dict, entry_point: str):
+    def __init__(self, nodes: list, edges: list, entry_point: str):
         """
         Initializes the graph with nodes, edges, and the entry point.
         """
-        self.nodes = {node.node_name: node for node in nodes}
-        self.edges = self._create_edges(edges)
+
+        self.nodes = nodes
+        self.edges = self._create_edges(edges)
         self.entry_point = entry_point.node_name
 
-    def _create_edges(self, edges: dict) -> dict:
+        if nodes[0].node_name != entry_point.node_name:
+            # raise a warning if the entry point is not the first node in the list
+            warnings.warn(
+                "Careful! The entry point node is different from the first node in the graph.")
+
+    def _create_edges(self, edges: list) -> dict:
         """
         Helper method to create a dictionary of edges from the given iterable of tuples.
 
@@ -51,8 +58,8 @@ def _create_edges(self, edges: dict) -> dict:
 
     def execute(self, initial_state: dict) -> dict:
         """
-        Executes the graph by traversing nodes starting from the entry point. The execution 
-        follows the edges based on the result of each node's execution and continues until 
+        Executes the graph by traversing nodes starting from the entry point. The execution
+        follows the edges based on the result of each node's execution and continues until
         it reaches a node with no outgoing edges.
 
         Args:
@@ -61,7 +68,8 @@ def execute(self, initial_state: dict) -> dict:
         Returns:
             dict: The state after execution has completed, which may have been altered by the nodes.
         """
-        current_node_name = self.entry_point
+        print(self.nodes)
+        current_node_name = self.nodes[0]
         state = initial_state
 
         # variables for tracking execution info
@@ -75,10 +83,10 @@
             "total_cost_USD": 0.0,
         }
 
-        while current_node_name is not None:
+        for node in self.nodes:
             curr_time = time.time()
 
-            current_node = self.nodes[current_node_name]
+            current_node = node
             with get_openai_callback() as cb:
                 result = current_node.execute(state)
 
diff --git a/scrapegraphai/graphs/script_creator_graph.py b/scrapegraphai/graphs/script_creator_graph.py
index 06cc7a81..fa86eeb4 100644
--- a/scrapegraphai/graphs/script_creator_graph.py
+++ b/scrapegraphai/graphs/script_creator_graph.py
@@ -1,4 +1,4 @@
-""" 
+"""
 Module for creating the smart scraper
 """
 from .base_graph import BaseGraph
@@ -57,17 +57,17 @@ def _create_graph(self):
         )
 
         return BaseGraph(
-            nodes={
+            nodes=[
                 fetch_node,
                 parse_node,
                 rag_node,
                 generate_scraper_node,
-            },
-            edges={
+            ],
+            edges=[
                 (fetch_node, parse_node),
                 (parse_node, rag_node),
                 (rag_node, generate_scraper_node)
-            },
+            ],
             entry_point=fetch_node
         )
 
diff --git a/scrapegraphai/graphs/search_graph.py b/scrapegraphai/graphs/search_graph.py
index ad21e485..b48965dd 100644
--- a/scrapegraphai/graphs/search_graph.py
+++ b/scrapegraphai/graphs/search_graph.py
@@ -11,6 +11,7 @@
 )
 from .abstract_graph import AbstractGraph
 
+
 class SearchGraph(AbstractGraph):
     """
     Module for searching info on the internet
@@ -49,19 +50,19 @@ def _create_graph(self):
         )
 
         return BaseGraph(
-            nodes={
+            nodes=[
                 search_internet_node,
                 fetch_node,
                 parse_node,
                 rag_node,
                 generate_answer_node,
-            },
-            edges={
+            ],
+            edges=[
                 (search_internet_node, fetch_node),
                 (fetch_node, parse_node),
                 (parse_node, rag_node),
                 (rag_node, generate_answer_node)
-            },
+            ],
             entry_point=search_internet_node
         )
 
diff --git a/scrapegraphai/graphs/smart_scraper_graph.py b/scrapegraphai/graphs/smart_scraper_graph.py
index e413727b..5a520224 100644
--- a/scrapegraphai/graphs/smart_scraper_graph.py
+++ b/scrapegraphai/graphs/smart_scraper_graph.py
@@ -1,4 +1,4 @@
-""" 
+"""
 Module for creating the smart scraper
 """
 from .base_graph import BaseGraph
@@ -10,6 +10,7 @@
 )
 from .abstract_graph import AbstractGraph
 
+
 class SmartScraperGraph(AbstractGraph):
     """
     SmartScraper is a comprehensive web scraping tool that automates the process of extracting
@@ -52,17 +53,17 @@ def _create_graph(self):
         )
 
         return BaseGraph(
-            nodes={
+            nodes=[
                 fetch_node,
                 parse_node,
                 rag_node,
                 generate_answer_node,
-            },
-            edges={
+            ],
+            edges=[
                 (fetch_node, parse_node),
                 (parse_node, rag_node),
                 (rag_node, generate_answer_node)
-            },
+            ],
             entry_point=fetch_node
         )
 
@@ -70,7 +71,7 @@ def run(self) -> str:
         """
         Executes the web scraping process and returns the answer to the prompt.
""" - inputs = {"user_prompt": self.prompt, self.input_key: self.source} + inputs = {"user_prompt": self.prompt, self.input_key: self.source} self.final_state, self.execution_info = self.graph.execute(inputs) return self.final_state.get("answer", "No answer found.") diff --git a/scrapegraphai/graphs/speech_graph.py b/scrapegraphai/graphs/speech_graph.py index f050acb4..2b10077f 100644 --- a/scrapegraphai/graphs/speech_graph.py +++ b/scrapegraphai/graphs/speech_graph.py @@ -62,19 +62,19 @@ def _create_graph(self): ) return BaseGraph( - nodes={ + nodes=[ fetch_node, parse_node, rag_node, generate_answer_node, text_to_speech_node - }, - edges={ + ], + edges=[ (fetch_node, parse_node), (parse_node, rag_node), (rag_node, generate_answer_node), (generate_answer_node, text_to_speech_node) - }, + ], entry_point=fetch_node )