diff --git a/examples/openai/custom_graph_openai.py b/examples/openai/custom_graph_openai.py index dab82c1f..5744669b 100644 --- a/examples/openai/custom_graph_openai.py +++ b/examples/openai/custom_graph_openai.py @@ -34,7 +34,7 @@ robot_node = RobotsNode( input="url", output=["is_scrapable"], - node_config={"llm": llm_model} + node_config={"llm_model": llm_model} ) fetch_node = FetchNode( @@ -50,12 +50,12 @@ rag_node = RAGNode( input="user_prompt & (parsed_doc | doc)", output=["relevant_chunks"], - node_config={"llm": llm_model}, + node_config={"llm_model": llm_model}, ) generate_answer_node = GenerateAnswerNode( input="user_prompt & (relevant_chunks | parsed_doc | doc)", output=["answer"], - node_config={"llm": llm_model}, + node_config={"llm_model": llm_model}, ) # ************************************************ diff --git a/examples/single_node/robot_node.py b/examples/single_node/robot_node.py index 0e446262..257c4efb 100644 --- a/examples/single_node/robot_node.py +++ b/examples/single_node/robot_node.py @@ -26,7 +26,7 @@ robots_node = RobotsNode( input="url", output=["is_scrapable"], - node_config={"llm": llm_model, + node_config={"llm_model": llm_model, "headless": False } ) diff --git a/scrapegraphai/graphs/abstract_graph.py b/scrapegraphai/graphs/abstract_graph.py index aff7289c..9bafa019 100644 --- a/scrapegraphai/graphs/abstract_graph.py +++ b/scrapegraphai/graphs/abstract_graph.py @@ -52,15 +52,32 @@ def __init__(self, prompt: str, config: dict, source: Optional[str] = None): ) if "embeddings" not in config else self._create_embedder( config["embeddings"]) + # Create the graph + self.graph = self._create_graph() + self.final_state = None + self.execution_info = None + # Set common configuration parameters self.verbose = True if config is None else config.get("verbose", False) self.headless = True if config is None else config.get( "headless", True) + common_params = {"headless": self.headless, + "verbose": self.verbose, + "llm_model": self.llm_model, + "embedder_model": self.embedder_model} + self.set_common_params(common_params, overwrite=False) - # Create the graph - self.graph = self._create_graph() - self.final_state = None - self.execution_info = None + + def set_common_params(self, params: dict, overwrite=False): + """ + Pass parameters to every node in the graph unless otherwise defined in the graph. + + Args: + params (dict): Common parameters and their values. + """ + + for node in self.graph.nodes: + node.update_config(params, overwrite) def _set_model_token(self, llm): diff --git a/scrapegraphai/graphs/csv_scraper_graph.py b/scrapegraphai/graphs/csv_scraper_graph.py index 9a5eb931..24c19234 100644 --- a/scrapegraphai/graphs/csv_scraper_graph.py +++ b/scrapegraphai/graphs/csv_scraper_graph.py @@ -32,34 +32,27 @@ def _create_graph(self): fetch_node = FetchNode( input="csv_dir", output=["doc"], - node_config={ - "headless": self.headless, - "verbose": self.verbose - } ) parse_node = ParseNode( input="doc", output=["parsed_doc"], node_config={ "chunk_size": self.model_token, - "verbose": self.verbose } ) rag_node = RAGNode( input="user_prompt & (parsed_doc | doc)", output=["relevant_chunks"], node_config={ - "llm": self.llm_model, + "llm_model": self.llm_model, "embedder_model": self.embedder_model, - "verbose": self.verbose } ) generate_answer_node = GenerateAnswerCSVNode( input="user_prompt & (relevant_chunks | parsed_doc | doc)", output=["answer"], node_config={ - "llm": self.llm_model, - "verbose": self.verbose + "llm_model": self.llm_model, } ) @@ -85,4 +78,4 @@ def run(self) -> str: inputs = {"user_prompt": self.prompt, self.input_key: self.source} self.final_state, self.execution_info = self.graph.execute(inputs) - return self.final_state.get("answer", "No answer found.") + return self.final_state.get("answer", "No answer found.") \ No newline at end of file diff --git a/scrapegraphai/graphs/json_scraper_graph.py b/scrapegraphai/graphs/json_scraper_graph.py index f7392212..aec41195 100644 --- a/scrapegraphai/graphs/json_scraper_graph.py +++ b/scrapegraphai/graphs/json_scraper_graph.py @@ -56,34 +56,27 @@ def _create_graph(self) -> BaseGraph: fetch_node = FetchNode( input="json_dir", output=["doc"], - node_config={ - "headless": self.headless, - "verbose": self.verbose - } ) parse_node = ParseNode( input="doc", output=["parsed_doc"], node_config={ - "chunk_size": self.model_token, - "verbose": self.verbose + "chunk_size": self.model_token } ) rag_node = RAGNode( input="user_prompt & (parsed_doc | doc)", output=["relevant_chunks"], node_config={ - "llm": self.llm_model, - "embedder_model": self.embedder_model, - "verbose": self.verbose + "llm_model": self.llm_model, + "embedder_model": self.embedder_model } ) generate_answer_node = GenerateAnswerNode( input="user_prompt & (relevant_chunks | parsed_doc | doc)", output=["answer"], node_config={ - "llm": self.llm_model, - "verbose": self.verbose + "llm_model": self.llm_model } ) @@ -113,4 +106,4 @@ def run(self) -> str: inputs = {"user_prompt": self.prompt, self.input_key: self.source} self.final_state, self.execution_info = self.graph.execute(inputs) - return self.final_state.get("answer", "No answer found.") + return self.final_state.get("answer", "No answer found.") \ No newline at end of file diff --git a/scrapegraphai/graphs/script_creator_graph.py b/scrapegraphai/graphs/script_creator_graph.py index 105048db..5ffc358b 100644 --- a/scrapegraphai/graphs/script_creator_graph.py +++ b/scrapegraphai/graphs/script_creator_graph.py @@ -61,32 +61,25 @@ def _create_graph(self) -> BaseGraph: fetch_node = FetchNode( input="url | local_dir", output=["doc"], - node_config={ - "headless": self.headless, - "verbose": self.verbose - } ) parse_node = ParseNode( input="doc", output=["parsed_doc"], node_config={"chunk_size": self.model_token, - "verbose": self.verbose } ) rag_node = RAGNode( input="user_prompt & (parsed_doc | doc)", output=["relevant_chunks"], node_config={ - "llm": self.llm_model, - "embedder_model": self.embedder_model, - "verbose": self.verbose + "llm_model": self.llm_model, + "embedder_model": self.embedder_model } ) generate_scraper_node = GenerateScraperNode( input="user_prompt & (relevant_chunks | parsed_doc | doc)", output=["answer"], - node_config={"llm": self.llm_model, - "verbose": self.verbose}, + node_config={"llm_model": self.llm_model}, library=self.library, website=self.source ) @@ -117,4 +110,4 @@ def run(self) -> str: inputs = {"user_prompt": self.prompt, self.input_key: self.source} self.final_state, self.execution_info = self.graph.execute(inputs) - return self.final_state.get("answer", "No answer found.") + return self.final_state.get("answer", "No answer found.") \ No newline at end of file diff --git a/scrapegraphai/graphs/search_graph.py b/scrapegraphai/graphs/search_graph.py index 41548a77..9c463e1a 100644 --- a/scrapegraphai/graphs/search_graph.py +++ b/scrapegraphai/graphs/search_graph.py @@ -50,41 +50,33 @@ def _create_graph(self) -> BaseGraph: input="user_prompt", output=["url"], node_config={ - "llm": self.llm_model, - "verbose": self.verbose + "llm_model": self.llm_model } ) fetch_node = FetchNode( input="url | local_dir", - output=["doc"], - node_config={ - "headless": self.headless, - "verbose": self.verbose - } + output=["doc"] ) parse_node = ParseNode( input="doc", output=["parsed_doc"], node_config={ - "chunk_size": self.model_token, - "verbose": self.verbose + "chunk_size": self.model_token } ) rag_node = RAGNode( input="user_prompt & (parsed_doc | doc)", output=["relevant_chunks"], node_config={ - "llm": self.llm_model, - "embedder_model": self.embedder_model, - "verbose": self.verbose + "llm_model": self.llm_model, + "embedder_model": self.embedder_model } ) generate_answer_node = GenerateAnswerNode( input="user_prompt & (relevant_chunks | parsed_doc | doc)", output=["answer"], node_config={ - "llm": self.llm_model, - "verbose": self.verbose + "llm_model": self.llm_model } ) @@ -116,4 +108,4 @@ def run(self) -> str: inputs = {"user_prompt": self.prompt} self.final_state, self.execution_info = self.graph.execute(inputs) - return self.final_state.get("answer", "No answer found.") + return self.final_state.get("answer", "No answer found.") \ No newline at end of file diff --git a/scrapegraphai/graphs/smart_scraper_graph.py b/scrapegraphai/graphs/smart_scraper_graph.py index 4d6b0e93..a9e63823 100644 --- a/scrapegraphai/graphs/smart_scraper_graph.py +++ b/scrapegraphai/graphs/smart_scraper_graph.py @@ -57,35 +57,28 @@ def _create_graph(self) -> BaseGraph: """ fetch_node = FetchNode( input="url | local_dir", - output=["doc"], - node_config={ - "headless": self.headless, - "verbose": self.verbose - } + output=["doc"] ) parse_node = ParseNode( input="doc", output=["parsed_doc"], node_config={ - "chunk_size": self.model_token, - "verbose": self.verbose + "chunk_size": self.model_token } ) rag_node = RAGNode( input="user_prompt & (parsed_doc | doc)", output=["relevant_chunks"], node_config={ - "llm": self.llm_model, - "embedder_model": self.embedder_model, - "verbose": self.verbose + "llm_model": self.llm_model, + "embedder_model": self.embedder_model } ) generate_answer_node = GenerateAnswerNode( input="user_prompt & (relevant_chunks | parsed_doc | doc)", output=["answer"], node_config={ - "llm": self.llm_model, - "verbose": self.verbose + "llm_model": self.llm_model } ) @@ -115,4 +108,4 @@ def run(self) -> str: inputs = {"user_prompt": self.prompt, self.input_key: self.source} self.final_state, self.execution_info = self.graph.execute(inputs) - return self.final_state.get("answer", "No answer found.") + return self.final_state.get("answer", "No answer found.") \ No newline at end of file diff --git a/scrapegraphai/graphs/speech_graph.py b/scrapegraphai/graphs/speech_graph.py index 3edadfd0..3ca2b703 100644 --- a/scrapegraphai/graphs/speech_graph.py +++ b/scrapegraphai/graphs/speech_graph.py @@ -56,43 +56,34 @@ def _create_graph(self) -> BaseGraph: fetch_node = FetchNode( input="url | local_dir", - output=["doc"], - node_config={ - "headless": self.headless, - "verbose": self.verbose - } + output=["doc"] ) parse_node = ParseNode( input="doc", output=["parsed_doc"], node_config={ - "chunk_size": self.model_token, - "verbose": self.verbose + "chunk_size": self.model_token } ) rag_node = RAGNode( input="user_prompt & (parsed_doc | doc)", output=["relevant_chunks"], node_config={ - "llm": self.llm_model, - "embedder_model": self.embedder_model, - "verbose": self.verbose - } + "llm_model": self.llm_model, + "embedder_model": self.embedder_model } ) generate_answer_node = GenerateAnswerNode( input="user_prompt & (relevant_chunks | parsed_doc | doc)", output=["answer"], node_config={ - "llm": self.llm_model, - "verbose": self.verbose + "llm_model": self.llm_model } ) text_to_speech_node = TextToSpeechNode( input="answer", output=["audio"], node_config={ - "tts_model": OpenAITextToSpeech(self.config["tts_model"]), - "verbose": self.verbose + "tts_model": OpenAITextToSpeech(self.config["tts_model"]) } ) @@ -131,4 +122,4 @@ def run(self) -> str: "output_path", "output.mp3")) print(f"Audio saved to {self.config.get('output_path', 'output.mp3')}") - return self.final_state.get("answer", "No answer found.") + return self.final_state.get("answer", "No answer found.") \ No newline at end of file diff --git a/scrapegraphai/graphs/xml_scraper_graph.py b/scrapegraphai/graphs/xml_scraper_graph.py index c84e1506..945dc165 100644 --- a/scrapegraphai/graphs/xml_scraper_graph.py +++ b/scrapegraphai/graphs/xml_scraper_graph.py @@ -57,35 +57,28 @@ def _create_graph(self) -> BaseGraph: fetch_node = FetchNode( input="xml_dir", - output=["doc"], - node_config={ - "headless": self.headless, - "verbose": self.verbose - } + output=["doc"] ) parse_node = ParseNode( input="doc", output=["parsed_doc"], node_config={ - "chunk_size": self.model_token, - "verbose": self.verbose + "chunk_size": self.model_token } ) rag_node = RAGNode( input="user_prompt & (parsed_doc | doc)", output=["relevant_chunks"], node_config={ - "llm": self.llm_model, - "embedder_model": self.embedder_model, - "verbose": self.verbose + "llm_model": self.llm_model, + "embedder_model": self.embedder_model } ) generate_answer_node = GenerateAnswerNode( input="user_prompt & (relevant_chunks | parsed_doc | doc)", output=["answer"], node_config={ - "llm": self.llm_model, - "verbose": self.verbose + "llm_model": self.llm_model } ) @@ -115,4 +108,4 @@ def run(self) -> str: inputs = {"user_prompt": self.prompt, self.input_key: self.source} self.final_state, self.execution_info = self.graph.execute(inputs) - return self.final_state.get("answer", "No answer found.") + return self.final_state.get("answer", "No answer found.") \ No newline at end of file diff --git a/scrapegraphai/nodes/base_node.py b/scrapegraphai/nodes/base_node.py index f3329320..cabfeda0 100644 --- a/scrapegraphai/nodes/base_node.py +++ b/scrapegraphai/nodes/base_node.py @@ -68,6 +68,21 @@ def execute(self, state: dict) -> dict: pass + def update_config(self, params: dict, overwrite: bool = False): + """ + Updates the node_config dictionary as well as attributes with same key. + + Args: + param (dict): The dictionary to update node_config with. + overwrite (bool): Flag indicating if the values of node_config should be overwritten if their value is not None. + """ + if self.node_config is None: + self.node_config = {} + for key, val in params.items(): + if hasattr(self, key) and (key not in self.node_config or overwrite): + self.node_config[key] = val + setattr(self, key, val) + def get_input_keys(self, state: dict) -> List[str]: """ Determines the necessary state keys based on the input specification. diff --git a/scrapegraphai/nodes/fetch_node.py b/scrapegraphai/nodes/fetch_node.py index f873654d..82d67949 100644 --- a/scrapegraphai/nodes/fetch_node.py +++ b/scrapegraphai/nodes/fetch_node.py @@ -29,7 +29,7 @@ class FetchNode(BaseNode): node_name (str): The unique identifier name for the node, defaulting to "Fetch". """ - def __init__(self, input: str, output: List[str], node_config: Optional[dict], node_name: str = "Fetch"): + def __init__(self, input: str, output: List[str], node_config: Optional[dict]=None, node_name: str = "Fetch"): super().__init__(node_name, "node", input, output, 1) self.headless = True if node_config is None else node_config.get("headless", True) diff --git a/scrapegraphai/nodes/generate_answer_csv_node.py b/scrapegraphai/nodes/generate_answer_csv_node.py index ac861816..6d2b84fc 100644 --- a/scrapegraphai/nodes/generate_answer_csv_node.py +++ b/scrapegraphai/nodes/generate_answer_csv_node.py @@ -22,14 +22,14 @@ class GenerateAnswerCSVNode(BaseNode): an answer. Attributes: - llm: An instance of a language model client, configured for generating answers. + llm_model: An instance of a language model client, configured for generating answers. node_name (str): The unique identifier name for the node, defaulting to "GenerateAnswerNodeCsv". node_type (str): The type of the node, set to "node" indicating a standard operational node. Args: - llm: An instance of the language model client (e.g., ChatOpenAI) used + llm_model: An instance of the language model client (e.g., ChatOpenAI) used for generating answers. node_name (str, optional): The unique identifier name for the node. Defaults to "GenerateAnswerNodeCsv". @@ -44,11 +44,11 @@ def __init__(self, input: str, output: List[str], node_config: dict, """ Initializes the GenerateAnswerNodeCsv with a language model client and a node name. Args: - llm: An instance of the OpenAIImageToText class. + llm_model: An instance of the OpenAIImageToText class. node_name (str): name of the node """ super().__init__(node_name, "node", input, output, 2, node_config) - self.llm_model = node_config["llm"] + self.llm_model = node_config["llm_model"] self.verbose = True if node_config is None else node_config.get( "verbose", False) diff --git a/scrapegraphai/nodes/generate_answer_node.py b/scrapegraphai/nodes/generate_answer_node.py index e9b4dd40..df3078ef 100644 --- a/scrapegraphai/nodes/generate_answer_node.py +++ b/scrapegraphai/nodes/generate_answer_node.py @@ -37,7 +37,7 @@ def __init__(self, input: str, output: List[str], node_config: dict, node_name: str = "GenerateAnswer"): super().__init__(node_name, "node", input, output, 2, node_config) - self.llm_model = node_config["llm"] + self.llm_model = node_config["llm_model"] self.verbose = True if node_config is None else node_config.get("verbose", False) def execute(self, state: dict) -> dict: diff --git a/scrapegraphai/nodes/generate_answer_node_csv.py b/scrapegraphai/nodes/generate_answer_node_csv.py index ac861816..6d2b84fc 100644 --- a/scrapegraphai/nodes/generate_answer_node_csv.py +++ b/scrapegraphai/nodes/generate_answer_node_csv.py @@ -22,14 +22,14 @@ class GenerateAnswerCSVNode(BaseNode): an answer. Attributes: - llm: An instance of a language model client, configured for generating answers. + llm_model: An instance of a language model client, configured for generating answers. node_name (str): The unique identifier name for the node, defaulting to "GenerateAnswerNodeCsv". node_type (str): The type of the node, set to "node" indicating a standard operational node. Args: - llm: An instance of the language model client (e.g., ChatOpenAI) used + llm_model: An instance of the language model client (e.g., ChatOpenAI) used for generating answers. node_name (str, optional): The unique identifier name for the node. Defaults to "GenerateAnswerNodeCsv". @@ -44,11 +44,11 @@ def __init__(self, input: str, output: List[str], node_config: dict, """ Initializes the GenerateAnswerNodeCsv with a language model client and a node name. Args: - llm: An instance of the OpenAIImageToText class. + llm_model: An instance of the OpenAIImageToText class. node_name (str): name of the node """ super().__init__(node_name, "node", input, output, 2, node_config) - self.llm_model = node_config["llm"] + self.llm_model = node_config["llm_model"] self.verbose = True if node_config is None else node_config.get( "verbose", False) diff --git a/scrapegraphai/nodes/generate_scraper_node.py b/scrapegraphai/nodes/generate_scraper_node.py index 9c80fc19..2e1f959e 100644 --- a/scrapegraphai/nodes/generate_scraper_node.py +++ b/scrapegraphai/nodes/generate_scraper_node.py @@ -40,7 +40,7 @@ def __init__(self, input: str, output: List[str], node_config: dict, library: str, website: str, node_name: str = "GenerateAnswer"): super().__init__(node_name, "node", input, output, 2, node_config) - self.llm_model = node_config["llm"] + self.llm_model = node_config["llm_model"] self.library = library self.source = website diff --git a/scrapegraphai/nodes/rag_node.py b/scrapegraphai/nodes/rag_node.py index b883845a..8c692ec8 100644 --- a/scrapegraphai/nodes/rag_node.py +++ b/scrapegraphai/nodes/rag_node.py @@ -34,7 +34,7 @@ class RAGNode(BaseNode): def __init__(self, input: str, output: List[str], node_config: dict, node_name: str = "RAG"): super().__init__(node_name, "node", input, output, 2, node_config) - self.llm_model = node_config["llm"] + self.llm_model = node_config["llm_model"] self.embedder_model = node_config.get("embedder_model", None) self.verbose = True if node_config is None else node_config.get( "verbose", False) diff --git a/scrapegraphai/nodes/robots_node.py b/scrapegraphai/nodes/robots_node.py index 001de62d..8c341183 100644 --- a/scrapegraphai/nodes/robots_node.py +++ b/scrapegraphai/nodes/robots_node.py @@ -38,7 +38,7 @@ def __init__(self, input: str, output: List[str], node_config: dict, force_scra node_name: str = "Robots"): super().__init__(node_name, "node", input, output, 1) - self.llm_model = node_config["llm"] + self.llm_model = node_config["llm_model"] self.force_scraping = force_scraping self.verbose = True if node_config is None else node_config.get("verbose", False) diff --git a/scrapegraphai/nodes/search_internet_node.py b/scrapegraphai/nodes/search_internet_node.py index 00cf9211..01095ef8 100644 --- a/scrapegraphai/nodes/search_internet_node.py +++ b/scrapegraphai/nodes/search_internet_node.py @@ -31,7 +31,7 @@ def __init__(self, input: str, output: List[str], node_config: dict, node_name: str = "SearchInternet"): super().__init__(node_name, "node", input, output, 1, node_config) - self.llm_model = node_config["llm"] + self.llm_model = node_config["llm_model"] self.verbose = True if node_config is None else node_config.get("verbose", False) def execute(self, state: dict) -> dict: diff --git a/scrapegraphai/nodes/search_link_node.py b/scrapegraphai/nodes/search_link_node.py index 7f766b5b..037b862e 100644 --- a/scrapegraphai/nodes/search_link_node.py +++ b/scrapegraphai/nodes/search_link_node.py @@ -37,7 +37,7 @@ def __init__(self, input: str, output: List[str], node_config: dict, node_name: str = "GenerateLinks"): super().__init__(node_name, "node", input, output, 1, node_config) - self.llm_model = node_config["llm"] + self.llm_model = node_config["llm_model"] self.verbose = True if node_config is None else node_config.get("verbose", False) def execute(self, state: dict) -> dict: