Skip to content
Merged
6 changes: 3 additions & 3 deletions examples/openai/custom_graph_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
robot_node = RobotsNode(
input="url",
output=["is_scrapable"],
node_config={"llm": llm_model}
node_config={"llm_model": llm_model}
)

fetch_node = FetchNode(
Expand All @@ -50,12 +50,12 @@
rag_node = RAGNode(
input="user_prompt & (parsed_doc | doc)",
output=["relevant_chunks"],
node_config={"llm": llm_model},
node_config={"llm_model": llm_model},
)
generate_answer_node = GenerateAnswerNode(
input="user_prompt & (relevant_chunks | parsed_doc | doc)",
output=["answer"],
node_config={"llm": llm_model},
node_config={"llm_model": llm_model},
)

# ************************************************
Expand Down
2 changes: 1 addition & 1 deletion examples/single_node/robot_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
robots_node = RobotsNode(
input="url",
output=["is_scrapable"],
node_config={"llm": llm_model,
node_config={"llm_model": llm_model,
"headless": False
}
)
Expand Down
25 changes: 21 additions & 4 deletions scrapegraphai/graphs/abstract_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,15 +52,32 @@ def __init__(self, prompt: str, config: dict, source: Optional[str] = None):
) if "embeddings" not in config else self._create_embedder(
config["embeddings"])

# Create the graph
self.graph = self._create_graph()
self.final_state = None
self.execution_info = None

# Set common configuration parameters
self.verbose = True if config is None else config.get("verbose", False)
self.headless = True if config is None else config.get(
"headless", True)
common_params = {"headless": self.headless,
"verbose": self.verbose,
"llm_model": self.llm_model,
"embedder_model": self.embedder_model}
self.set_common_params(common_params, overwrite=False)

# Create the graph
self.graph = self._create_graph()
self.final_state = None
self.execution_info = None

def set_common_params(self, params: dict, overwrite: bool = False) -> None:
    """
    Pass parameters to every node in the graph unless otherwise defined in the graph.

    Args:
        params (dict): Common parameters and their values.
        overwrite (bool): Forwarded to each node's ``update_config``; presumably
            controls whether values already set on a node are replaced — confirm
            against the node implementation. Defaults to False.
    """

    # Broadcast the shared config to every node; per-node precedence is
    # decided inside node.update_config, not here.
    for node in self.graph.nodes:
        node.update_config(params, overwrite)

def _set_model_token(self, llm):

Expand Down
13 changes: 3 additions & 10 deletions scrapegraphai/graphs/csv_scraper_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,34 +32,27 @@ def _create_graph(self):
fetch_node = FetchNode(
input="csv_dir",
output=["doc"],
node_config={
"headless": self.headless,
"verbose": self.verbose
}
)
parse_node = ParseNode(
input="doc",
output=["parsed_doc"],
node_config={
"chunk_size": self.model_token,
"verbose": self.verbose
}
)
rag_node = RAGNode(
input="user_prompt & (parsed_doc | doc)",
output=["relevant_chunks"],
node_config={
"llm": self.llm_model,
"llm_model": self.llm_model,
"embedder_model": self.embedder_model,
"verbose": self.verbose
}
)
generate_answer_node = GenerateAnswerCSVNode(
input="user_prompt & (relevant_chunks | parsed_doc | doc)",
output=["answer"],
node_config={
"llm": self.llm_model,
"verbose": self.verbose
"llm_model": self.llm_model,
}
)

Expand All @@ -85,4 +78,4 @@ def run(self) -> str:
inputs = {"user_prompt": self.prompt, self.input_key: self.source}
self.final_state, self.execution_info = self.graph.execute(inputs)

return self.final_state.get("answer", "No answer found.")
return self.final_state.get("answer", "No answer found.")
17 changes: 5 additions & 12 deletions scrapegraphai/graphs/json_scraper_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,34 +56,27 @@ def _create_graph(self) -> BaseGraph:
fetch_node = FetchNode(
input="json_dir",
output=["doc"],
node_config={
"headless": self.headless,
"verbose": self.verbose
}
)
parse_node = ParseNode(
input="doc",
output=["parsed_doc"],
node_config={
"chunk_size": self.model_token,
"verbose": self.verbose
"chunk_size": self.model_token
}
)
rag_node = RAGNode(
input="user_prompt & (parsed_doc | doc)",
output=["relevant_chunks"],
node_config={
"llm": self.llm_model,
"embedder_model": self.embedder_model,
"verbose": self.verbose
"llm_model": self.llm_model,
"embedder_model": self.embedder_model
}
)
generate_answer_node = GenerateAnswerNode(
input="user_prompt & (relevant_chunks | parsed_doc | doc)",
output=["answer"],
node_config={
"llm": self.llm_model,
"verbose": self.verbose
"llm_model": self.llm_model
}
)

Expand Down Expand Up @@ -113,4 +106,4 @@ def run(self) -> str:
inputs = {"user_prompt": self.prompt, self.input_key: self.source}
self.final_state, self.execution_info = self.graph.execute(inputs)

return self.final_state.get("answer", "No answer found.")
return self.final_state.get("answer", "No answer found.")
15 changes: 4 additions & 11 deletions scrapegraphai/graphs/script_creator_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,32 +61,25 @@ def _create_graph(self) -> BaseGraph:
fetch_node = FetchNode(
input="url | local_dir",
output=["doc"],
node_config={
"headless": self.headless,
"verbose": self.verbose
}
)
parse_node = ParseNode(
input="doc",
output=["parsed_doc"],
node_config={"chunk_size": self.model_token,
"verbose": self.verbose
}
)
rag_node = RAGNode(
input="user_prompt & (parsed_doc | doc)",
output=["relevant_chunks"],
node_config={
"llm": self.llm_model,
"embedder_model": self.embedder_model,
"verbose": self.verbose
"llm_model": self.llm_model,
"embedder_model": self.embedder_model
}
)
generate_scraper_node = GenerateScraperNode(
input="user_prompt & (relevant_chunks | parsed_doc | doc)",
output=["answer"],
node_config={"llm": self.llm_model,
"verbose": self.verbose},
node_config={"llm_model": self.llm_model},
library=self.library,
website=self.source
)
Expand Down Expand Up @@ -117,4 +110,4 @@ def run(self) -> str:
inputs = {"user_prompt": self.prompt, self.input_key: self.source}
self.final_state, self.execution_info = self.graph.execute(inputs)

return self.final_state.get("answer", "No answer found.")
return self.final_state.get("answer", "No answer found.")
22 changes: 7 additions & 15 deletions scrapegraphai/graphs/search_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,41 +50,33 @@ def _create_graph(self) -> BaseGraph:
input="user_prompt",
output=["url"],
node_config={
"llm": self.llm_model,
"verbose": self.verbose
"llm_model": self.llm_model
}
)
fetch_node = FetchNode(
input="url | local_dir",
output=["doc"],
node_config={
"headless": self.headless,
"verbose": self.verbose
}
output=["doc"]
)
parse_node = ParseNode(
input="doc",
output=["parsed_doc"],
node_config={
"chunk_size": self.model_token,
"verbose": self.verbose
"chunk_size": self.model_token
}
)
rag_node = RAGNode(
input="user_prompt & (parsed_doc | doc)",
output=["relevant_chunks"],
node_config={
"llm": self.llm_model,
"embedder_model": self.embedder_model,
"verbose": self.verbose
"llm_model": self.llm_model,
"embedder_model": self.embedder_model
}
)
generate_answer_node = GenerateAnswerNode(
input="user_prompt & (relevant_chunks | parsed_doc | doc)",
output=["answer"],
node_config={
"llm": self.llm_model,
"verbose": self.verbose
"llm_model": self.llm_model
}
)

Expand Down Expand Up @@ -116,4 +108,4 @@ def run(self) -> str:
inputs = {"user_prompt": self.prompt}
self.final_state, self.execution_info = self.graph.execute(inputs)

return self.final_state.get("answer", "No answer found.")
return self.final_state.get("answer", "No answer found.")
19 changes: 6 additions & 13 deletions scrapegraphai/graphs/smart_scraper_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,35 +57,28 @@ def _create_graph(self) -> BaseGraph:
"""
fetch_node = FetchNode(
input="url | local_dir",
output=["doc"],
node_config={
"headless": self.headless,
"verbose": self.verbose
}
output=["doc"]
)
parse_node = ParseNode(
input="doc",
output=["parsed_doc"],
node_config={
"chunk_size": self.model_token,
"verbose": self.verbose
"chunk_size": self.model_token
}
)
rag_node = RAGNode(
input="user_prompt & (parsed_doc | doc)",
output=["relevant_chunks"],
node_config={
"llm": self.llm_model,
"embedder_model": self.embedder_model,
"verbose": self.verbose
"llm_model": self.llm_model,
"embedder_model": self.embedder_model
}
)
generate_answer_node = GenerateAnswerNode(
input="user_prompt & (relevant_chunks | parsed_doc | doc)",
output=["answer"],
node_config={
"llm": self.llm_model,
"verbose": self.verbose
"llm_model": self.llm_model
}
)

Expand Down Expand Up @@ -115,4 +108,4 @@ def run(self) -> str:
inputs = {"user_prompt": self.prompt, self.input_key: self.source}
self.final_state, self.execution_info = self.graph.execute(inputs)

return self.final_state.get("answer", "No answer found.")
return self.final_state.get("answer", "No answer found.")
23 changes: 7 additions & 16 deletions scrapegraphai/graphs/speech_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,43 +56,34 @@ def _create_graph(self) -> BaseGraph:

fetch_node = FetchNode(
input="url | local_dir",
output=["doc"],
node_config={
"headless": self.headless,
"verbose": self.verbose
}
output=["doc"]
)
parse_node = ParseNode(
input="doc",
output=["parsed_doc"],
node_config={
"chunk_size": self.model_token,
"verbose": self.verbose
"chunk_size": self.model_token
}
)
rag_node = RAGNode(
input="user_prompt & (parsed_doc | doc)",
output=["relevant_chunks"],
node_config={
"llm": self.llm_model,
"embedder_model": self.embedder_model,
"verbose": self.verbose
}
"llm_model": self.llm_model,
"embedder_model": self.embedder_model }
)
generate_answer_node = GenerateAnswerNode(
input="user_prompt & (relevant_chunks | parsed_doc | doc)",
output=["answer"],
node_config={
"llm": self.llm_model,
"verbose": self.verbose
"llm_model": self.llm_model
}
)
text_to_speech_node = TextToSpeechNode(
input="answer",
output=["audio"],
node_config={
"tts_model": OpenAITextToSpeech(self.config["tts_model"]),
"verbose": self.verbose
"tts_model": OpenAITextToSpeech(self.config["tts_model"])
}
)

Expand Down Expand Up @@ -131,4 +122,4 @@ def run(self) -> str:
"output_path", "output.mp3"))
print(f"Audio saved to {self.config.get('output_path', 'output.mp3')}")

return self.final_state.get("answer", "No answer found.")
return self.final_state.get("answer", "No answer found.")
Loading