55from ..models import OpenAI , OpenAITextToSpeech
66from .base_graph import BaseGraph
77from ..nodes import (
8- FetchHTMLNode ,
8+ FetchNode ,
99 ParseNode ,
1010 RAGNode ,
1111 GenerateAnswerNode ,
1212 TextToSpeechNode ,
1313)
1414
1515
16- class SpeechSummaryGraph :
16+ class SpeechGraph :
1717 """
    SpeechGraph is a tool that automates the process of extracting and summarizing
    information from web pages, then converting that summary into spoken word via an MP3 file.
@@ -35,21 +35,18 @@ class SpeechSummaryGraph:
3535 output_path (str): The file path where the generated MP3 should be saved.
3636 """
3737
38- def __init__ (self , prompt : str , url : str , llm_config : dict ,
39- output_path : str = "website_summary.mp3" ):
38+ def __init__ (self , prompt : str , url : str , config : dict ):
4039 """
4140 Initializes the SmartScraper with a prompt, URL, and language model configuration.
4241 """
43- self .prompt = f" { prompt } - Save the summary in a key called 'summary'."
42+ self .prompt = prompt
4443 self .url = url
45- self .llm_config = llm_config
46- self .llm = self ._create_llm ()
47- self .output_path = output_path
48- self .text_to_speech_model = OpenAITextToSpeech (
49- llm_config , model = "tts-1" , voice = "alloy" )
44+ self .llm_model = self ._create_llm (config ["llm" ])
45+ self .output_path = config .get ("output_path" , "output.mp3" )
46+ self .text_to_speech_model = OpenAITextToSpeech (config ["tts_model" ])
5047 self .graph = self ._create_graph ()
5148
52- def _create_llm (self ):
49+ def _create_llm (self , llm_config : dict ):
5350 """
5451 Creates an instance of the ChatOpenAI class with the provided language model configuration.
5552
@@ -60,12 +57,11 @@ def _create_llm(self):
6057 ValueError: If 'api_key' is not provided in llm_config.
6158 """
6259 llm_defaults = {
63- "model_name" : "gpt-3.5-turbo" ,
6460 "temperature" : 0 ,
6561 "streaming" : True
6662 }
6763 # Update defaults with any LLM parameters that were provided
68- llm_params = {** llm_defaults , ** self . llm_config }
64+ llm_params = {** llm_defaults , ** llm_config }
6965 # Ensure the api_key is set, raise an error if it's not
7066 if "api_key" not in llm_params :
7167 raise ValueError ("LLM configuration must include an 'api_key'." )
@@ -79,28 +75,46 @@ def _create_graph(self):
7975 Returns:
8076 BaseGraph: An instance of the BaseGraph class.
8177 """
82- fetch_html_node = FetchHTMLNode ("fetch_html" )
83- parse_document_node = ParseNode (doc_type = "html" , chunks_size = 4000 , node_name = "parse_document" )
84- rag_node = RAGNode (self .llm , "rag" )
85- generate_answer_node = GenerateAnswerNode (self .llm , "generate_answer" )
78+ # define the nodes for the graph
79+ fetch_node = FetchNode (
80+ input = "url | local_dir" ,
81+ output = ["doc" ],
82+ )
83+ parse_node = ParseNode (
84+ input = "doc" ,
85+ output = ["parsed_doc" ],
86+ )
87+ rag_node = RAGNode (
88+ input = "user_prompt & (parsed_doc | doc)" ,
89+ output = ["relevant_chunks" ],
90+ model_config = {"llm_model" : self .llm_model },
91+ )
92+ generate_answer_node = GenerateAnswerNode (
93+ input = "user_prompt & (relevant_chunks | parsed_doc | doc)" ,
94+ output = ["answer" ],
95+ model_config = {"llm_model" : self .llm_model },
96+ )
8697 text_to_speech_node = TextToSpeechNode (
87- self .text_to_speech_model , "text_to_speech" )
98+ input = "answer" ,
99+ output = ["audio" ],
100+ model_config = {"tts_model" : self .text_to_speech_model },
101+ )
88102
89103 return BaseGraph (
90104 nodes = {
91- fetch_html_node ,
92- parse_document_node ,
105+ fetch_node ,
106+ parse_node ,
93107 rag_node ,
94108 generate_answer_node ,
95109 text_to_speech_node
96110 },
97111 edges = {
98- (fetch_html_node , parse_document_node ),
99- (parse_document_node , rag_node ),
112+ (fetch_node , parse_node ),
113+ (parse_node , rag_node ),
100114 (rag_node , generate_answer_node ),
101115 (generate_answer_node , text_to_speech_node )
102116 },
103- entry_point = fetch_html_node
117+ entry_point = fetch_node
104118 )
105119
106120 def run (self ) -> str :
@@ -110,7 +124,7 @@ def run(self) -> str:
110124 Returns:
111125 str: The answer extracted from the web page, corresponding to the given prompt.
112126 """
113- inputs = {"user_input " : self .prompt , "url" : self .url }
127+ inputs = {"user_prompt " : self .prompt , "url" : self .url }
114128 final_state = self .graph .execute (inputs )
115129
116130 audio = final_state .get ("audio" , None )