Commit 52934bf

implemented graph_config, fixed smart_scraper and speech graph
1 parent f27e0b4 commit 52934bf

13 files changed: +158 additions, -116 deletions
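At a glance, each graph now takes a single graph_config dictionary instead of a flat llm_config argument. A minimal sketch of the new shape, inferred from the updated examples in this commit (key values are placeholders; the embedding_model block appears commented out in the example and looks optional):

# Unified configuration sketch, inferred from the example files below.
graph_config = {
    "llm": {
        "api_key": "sk-...",        # placeholder; required, _create_llm raises ValueError without it
        "model": "gpt-3.5-turbo",   # replaces the old "model_name" key
    },
    # Graph-specific blocks shown in the SpeechGraph example:
    # "tts_model": {"api_key": "sk-...", "model": "tts-1", "voice": "alloy"},
    # "output_path": "website_summary.mp3",
}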

examples/graph_examples/smart_scraper_example.py

Lines changed: 16 additions & 10 deletions
@@ -7,20 +7,26 @@
 from scrapegraphai.graphs import SmartScraperGraph
 
 load_dotenv()
-
-# Define the configuration for the language model
 openai_key = os.getenv("OPENAI_APIKEY")
-llm_config = {
-    "api_key": openai_key,
-    "model_name": "gpt-3.5-turbo",
-}
 
-# Define URL and PROMPT
-URL = "https://www.ansa.it/veneto/"
-PROMPT = "List me all the news with their description."
+# Define the configuration for the graph
+graph_config = {
+    "llm": {
+        "api_key": openai_key,
+        "model": "gpt-3.5-turbo",
+    },
+    # "embedding_model": {
+    #     "api_key": openai_key,
+    #     "model": "gpt-3.5-turbo",
+    # },
+}
 
 # Create the SmartScraperGraph instance
-smart_scraper_graph = SmartScraperGraph(PROMPT, URL, llm_config)
+smart_scraper_graph = SmartScraperGraph(
+    prompt = "List me all the news with their description.",
+    url = "https://www.ansa.it/veneto/",
+    config = graph_config
+)
 
 answer = smart_scraper_graph.run()
 print(answer)
Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
+"""
+Basic example of scraping pipeline using SpeechSummaryGraph
+"""
+
+import os
+from dotenv import load_dotenv
+from scrapegraphai.graphs import SpeechGraph
+
+load_dotenv()
+openai_key = os.getenv("OPENAI_APIKEY")
+
+# Save the audio to a file
+file_name = "website_summary.mp3"
+curr_dir = os.path.dirname(os.path.realpath(__file__))
+output_path = os.path.join(curr_dir, file_name)
+
+# Define the configuration for the graph
+graph_config = {
+    "llm": {
+        "api_key": openai_key,
+        "model": "gpt-3.5-turbo",
+    },
+    "tts_model": {
+        "api_key": openai_key,
+        "model": "tts-1",
+        "voice": "alloy"
+    },
+    "output_path": output_path,
+}
+
+speech_graph = SpeechGraph(
+    prompt = "List me all the projects and generate and audio for me to listen to.",
+    url = "https://perinim.github.io/projects/",
+    config = graph_config,
+)
+
+final_state = speech_graph.run()
+print(final_state.get("answer", "No answer found."))

examples/graph_examples/speech_summary_graph_example.py

Lines changed: 0 additions & 27 deletions
This file was deleted.

scrapegraphai/graphs/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -3,4 +3,4 @@
 """
 from .base_graph import BaseGraph
 from .smart_scraper_graph import SmartScraperGraph
-from .speech_summary_graph import SpeechSummaryGraph
+from .speech_graph import SpeechGraph

scrapegraphai/graphs/smart_scraper_graph.py

Lines changed: 30 additions & 17 deletions
@@ -4,7 +4,7 @@
 from ..models import OpenAI
 from .base_graph import BaseGraph
 from ..nodes import (
-    FetchHTMLNode,
+    FetchNode,
     ParseNode,
     RAGNode,
     GenerateAnswerNode
@@ -34,17 +34,17 @@ class SmartScraperGraph:
     'temperature', and 'streaming'.
     """
 
-    def __init__(self, prompt: str, url: str, llm_config: dict):
+    def __init__(self, prompt: str, url: str, config: dict):
         """
         Initializes the SmartScraper with a prompt, URL, and language model configuration.
         """
         self.prompt = prompt
         self.url = url
-        self.llm_config = llm_config
-        self.llm = self._create_llm()
+        self.config = config
+        self.llm_model = self._create_llm(config["llm"])
         self.graph = self._create_graph()
 
-    def _create_llm(self):
+    def _create_llm(self, llm_config: dict):
         """
         Creates an instance of the ChatOpenAI class with the provided language model configuration.
 
@@ -55,12 +55,11 @@ def _create_llm(self):
             ValueError: If 'api_key' is not provided in llm_config.
         """
         llm_defaults = {
-            "model_name": "gpt-3.5-turbo",
             "temperature": 0,
             "streaming": True
         }
         # Update defaults with any LLM parameters that were provided
-        llm_params = {**llm_defaults, **self.llm_config}
+        llm_params = {**llm_defaults, **llm_config}
         # Ensure the api_key is set, raise an error if it's not
         if "api_key" not in llm_params:
             raise ValueError("LLM configuration must include an 'api_key'.")
@@ -75,24 +74,38 @@ def _create_graph(self):
             BaseGraph: An instance of the BaseGraph class.
         """
         # define the nodes for the graph
-        fetch_html_node = FetchHTMLNode("fetch_html")
-        parse_document_node = ParseNode(doc_type="html", chunks_size=4000, node_name="parse_document")
-        rag_node = RAGNode(self.llm, "rag")
-        generate_answer_node = GenerateAnswerNode(self.llm, "generate_answer")
+        fetch_node = FetchNode(
+            input="url | local_dir",
+            output=["doc"],
+        )
+        parse_node = ParseNode(
+            input="doc",
+            output=["parsed_doc"],
+        )
+        rag_node = RAGNode(
+            input="user_prompt & (parsed_doc | doc)",
+            output=["relevant_chunks"],
+            model_config={"llm_model": self.llm_model},
+        )
+        generate_answer_node = GenerateAnswerNode(
+            input="user_prompt & (relevant_chunks | parsed_doc | doc)",
+            output=["answer"],
+            model_config={"llm_model": self.llm_model},
+        )
 
         return BaseGraph(
             nodes={
-                fetch_html_node,
-                parse_document_node,
+                fetch_node,
+                parse_node,
                 rag_node,
                 generate_answer_node,
             },
             edges={
-                (fetch_html_node, parse_document_node),
-                (parse_document_node, rag_node),
+                (fetch_node, parse_node),
+                (parse_node, rag_node),
                 (rag_node, generate_answer_node)
             },
-            entry_point=fetch_html_node
+            entry_point=fetch_node
         )
 
     def run(self) -> str:
@@ -102,7 +115,7 @@ def run(self) -> str:
         Returns:
             str: The answer extracted from the web page, corresponding to the given prompt.
         """
-        inputs = {"user_input": self.prompt, "url": self.url}
+        inputs = {"user_prompt": self.prompt, "url": self.url}
         final_state = self.graph.execute(inputs)
 
         return final_state.get("answer", "No answer found.")
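As a side note on the updated _create_llm: the per-graph "llm" block is merged over the hard-coded defaults with plain dict unpacking, so any user-supplied key wins while missing ones fall back. A small standalone sketch of that merge (values are illustrative):

llm_defaults = {"temperature": 0, "streaming": True}
llm_config = {"api_key": "sk-...", "model": "gpt-3.5-turbo", "temperature": 0.7}  # illustrative values

# Later unpackings override earlier ones: the user's temperature (0.7) wins,
# while the streaming default (True) is kept.
llm_params = {**llm_defaults, **llm_config}

if "api_key" not in llm_params:  # same guard as in the diff above
    raise ValueError("LLM configuration must include an 'api_key'.")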

scrapegraphai/graphs/speech_summary_graph.py renamed to scrapegraphai/graphs/speech_graph.py

Lines changed: 38 additions & 24 deletions
@@ -5,15 +5,15 @@
 from ..models import OpenAI, OpenAITextToSpeech
 from .base_graph import BaseGraph
 from ..nodes import (
-    FetchHTMLNode,
+    FetchNode,
     ParseNode,
     RAGNode,
     GenerateAnswerNode,
     TextToSpeechNode,
 )
 
 
-class SpeechSummaryGraph:
+class SpeechGraph:
     """
     SpeechSummaryGraph is a tool that automates the process of extracting and summarizing
     information from web pages, then converting that summary into spoken word via an MP3 file.
@@ -35,21 +35,18 @@ class SpeechSummaryGraph:
     output_path (str): The file path where the generated MP3 should be saved.
     """
 
-    def __init__(self, prompt: str, url: str, llm_config: dict,
-                 output_path: str = "website_summary.mp3"):
+    def __init__(self, prompt: str, url: str, config: dict):
         """
         Initializes the SmartScraper with a prompt, URL, and language model configuration.
         """
-        self.prompt = f"{prompt} - Save the summary in a key called 'summary'."
+        self.prompt = prompt
         self.url = url
-        self.llm_config = llm_config
-        self.llm = self._create_llm()
-        self.output_path = output_path
-        self.text_to_speech_model = OpenAITextToSpeech(
-            llm_config, model="tts-1", voice="alloy")
+        self.llm_model = self._create_llm(config["llm"])
+        self.output_path = config.get("output_path", "output.mp3")
+        self.text_to_speech_model = OpenAITextToSpeech(config["tts_model"])
         self.graph = self._create_graph()
 
-    def _create_llm(self):
+    def _create_llm(self, llm_config: dict):
         """
         Creates an instance of the ChatOpenAI class with the provided language model configuration.
 
@@ -60,12 +57,11 @@ def _create_llm(self):
             ValueError: If 'api_key' is not provided in llm_config.
         """
         llm_defaults = {
-            "model_name": "gpt-3.5-turbo",
            "temperature": 0,
            "streaming": True
         }
         # Update defaults with any LLM parameters that were provided
-        llm_params = {**llm_defaults, **self.llm_config}
+        llm_params = {**llm_defaults, **llm_config}
         # Ensure the api_key is set, raise an error if it's not
         if "api_key" not in llm_params:
             raise ValueError("LLM configuration must include an 'api_key'.")
@@ -79,28 +75,46 @@ def _create_graph(self):
         Returns:
             BaseGraph: An instance of the BaseGraph class.
         """
-        fetch_html_node = FetchHTMLNode("fetch_html")
-        parse_document_node = ParseNode(doc_type="html", chunks_size=4000, node_name="parse_document")
-        rag_node = RAGNode(self.llm, "rag")
-        generate_answer_node = GenerateAnswerNode(self.llm, "generate_answer")
+        # define the nodes for the graph
+        fetch_node = FetchNode(
+            input="url | local_dir",
+            output=["doc"],
+        )
+        parse_node = ParseNode(
+            input="doc",
+            output=["parsed_doc"],
+        )
+        rag_node = RAGNode(
+            input="user_prompt & (parsed_doc | doc)",
+            output=["relevant_chunks"],
+            model_config={"llm_model": self.llm_model},
+        )
+        generate_answer_node = GenerateAnswerNode(
+            input="user_prompt & (relevant_chunks | parsed_doc | doc)",
+            output=["answer"],
+            model_config={"llm_model": self.llm_model},
+        )
         text_to_speech_node = TextToSpeechNode(
-            self.text_to_speech_model, "text_to_speech")
+            input="answer",
+            output=["audio"],
+            model_config={"tts_model": self.text_to_speech_model},
+        )
 
         return BaseGraph(
             nodes={
-                fetch_html_node,
-                parse_document_node,
+                fetch_node,
+                parse_node,
                 rag_node,
                 generate_answer_node,
                 text_to_speech_node
             },
             edges={
-                (fetch_html_node, parse_document_node),
-                (parse_document_node, rag_node),
+                (fetch_node, parse_node),
+                (parse_node, rag_node),
                 (rag_node, generate_answer_node),
                 (generate_answer_node, text_to_speech_node)
             },
-            entry_point=fetch_html_node
+            entry_point=fetch_node
         )
 
     def run(self) -> str:
@@ -110,7 +124,7 @@ def run(self) -> str:
         Returns:
             str: The answer extracted from the web page, corresponding to the given prompt.
         """
-        inputs = {"user_input": self.prompt, "url": self.url}
+        inputs = {"user_prompt": self.prompt, "url": self.url}
         final_state = self.graph.execute(inputs)
 
         audio = final_state.get("audio", None)

scrapegraphai/models/openai_tts.py

Lines changed: 4 additions & 4 deletions
@@ -22,7 +22,7 @@ class OpenAITextToSpeech:
     bytes of the generated speech.
     """
 
-    def __init__(self, llm_config: dict, model: str = "tts-1", voice: str = "alloy"):
+    def __init__(self, tts_config: dict):
         """
         Initializes an instance of the OpenAITextToSpeech class.
 
@@ -35,9 +35,9 @@ def __init__(self, llm_config: dict, model: str = "tts-1", voice: str = "alloy")
         """
 
         # convert model_name to model
-        self.client = OpenAI(api_key=llm_config.get("api_key"))
-        self.model = model
-        self.voice = voice
+        self.client = OpenAI(api_key=tts_config.get("api_key"))
+        self.model = tts_config.get("model", "tts-1")
+        self.voice = tts_config.get("voice", "alloy")
 
     def run(self, text):
         """

scrapegraphai/nodes/fetch_node.py

Lines changed: 1 addition & 1 deletion
@@ -35,7 +35,7 @@ class FetchNode(BaseNode):
     to succeed.
     """
 
-    def __init__(self, input: str, output: List[str], node_name: str = "FetchNode"):
+    def __init__(self, input: str, output: List[str], node_name: str = "Fetch"):
         """
         Initializes the FetchHTMLNode with a node name and node type.
         Arguments:

scrapegraphai/nodes/generate_answer_node.py

Lines changed: 1 addition & 1 deletion
@@ -38,7 +38,7 @@ class GenerateAnswerNode(BaseNode):
     updating the state with the generated answer under the 'answer' key.
     """
 
-    def __init__(self, input: str, output: List[str], model_config: dict, node_name: str = "GenerateAnswerNode"):
+    def __init__(self, input: str, output: List[str], model_config: dict, node_name: str = "GenerateAnswer"):
         """
         Initializes the GenerateAnswerNode with a language model client and a node name.
         Args:
