From b86aac2188887642564a34d13d55d0fcff220ec1 Mon Sep 17 00:00:00 2001
From: Shubham Kamboj
Date: Thu, 2 May 2024 20:09:23 +0530
Subject: [PATCH] feat: Allow end users to pass model instances for llm and
 embedding model

---
 examples/azure/smart_scraper_azure_openai.py | 63 ++++++++++++++++++++
 scrapegraphai/graphs/abstract_graph.py       | 20 ++++++-
 scrapegraphai/helpers/models_tokens.py       |  5 +-
 scrapegraphai/nodes/rag_node.py              |  3 +
 4 files changed, 88 insertions(+), 3 deletions(-)
 create mode 100644 examples/azure/smart_scraper_azure_openai.py

diff --git a/examples/azure/smart_scraper_azure_openai.py b/examples/azure/smart_scraper_azure_openai.py
new file mode 100644
index 00000000..bfcd6b92
--- /dev/null
+++ b/examples/azure/smart_scraper_azure_openai.py
@@ -0,0 +1,63 @@
+"""
+Basic example of scraping pipeline using SmartScraper using Azure OpenAI Key
+"""
+
+import os
+from dotenv import load_dotenv
+from langchain_openai import AzureChatOpenAI
+from langchain_openai import AzureOpenAIEmbeddings
+from scrapegraphai.graphs import SmartScraperGraph
+from scrapegraphai.utils import prettify_exec_info
+
+
+## required environment variable in .env
+# AZURE_OPENAI_ENDPOINT
+# AZURE_OPENAI_CHAT_DEPLOYMENT_NAME
+# MODEL_NAME
+# AZURE_OPENAI_API_KEY
+# OPENAI_API_TYPE
+# AZURE_OPENAI_API_VERSION
+# AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT_NAME
+load_dotenv()
+
+
+# ************************************************
+# Initialize the model instances
+# ************************************************
+
+llm_model_instance = AzureChatOpenAI(
+    openai_api_version=os.environ["AZURE_OPENAI_API_VERSION"],
+    azure_deployment=os.environ["AZURE_OPENAI_CHAT_DEPLOYMENT_NAME"]
+)
+
+embedder_model_instance = AzureOpenAIEmbeddings(
+    azure_deployment=os.environ["AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT_NAME"],
+    openai_api_version=os.environ["AZURE_OPENAI_API_VERSION"],
+)
+
+# ************************************************
+# Create the SmartScraperGraph instance and run it
+# ************************************************
+
+graph_config = {
+    "llm": {"model_instance": llm_model_instance},
+    "embeddings": {"model_instance": embedder_model_instance}
+}
+
+smart_scraper_graph = SmartScraperGraph(
+    prompt="List me all the events, with the following fields: company_name, event_name, event_start_date, event_start_time, event_end_date, event_end_time, location, event_mode, event_category, third_party_redirect, no_of_days, "
+    "time_in_hours, hosted_or_attending, refreshments_type, registration_available, registration_link",
+    # also accepts a string with the already downloaded HTML code
+    source="https://www.hmhco.com/event",
+    config=graph_config
+)
+
+result = smart_scraper_graph.run()
+print(result)
+
+# ************************************************
+# Get graph execution info
+# ************************************************
+
+graph_exec_info = smart_scraper_graph.get_execution_info()
+print(prettify_exec_info(graph_exec_info))
diff --git a/scrapegraphai/graphs/abstract_graph.py b/scrapegraphai/graphs/abstract_graph.py
index 5adf8ba6..e70c9e95 100644
--- a/scrapegraphai/graphs/abstract_graph.py
+++ b/scrapegraphai/graphs/abstract_graph.py
@@ -19,7 +19,7 @@ def __init__(self, prompt: str, config: dict, source: Optional[str] = None):
         self.prompt = prompt
         self.source = source
         self.config = config
-        self.llm_model = self._create_llm(config["llm"])
+        self.llm_model = self._create_llm(config["llm"], chat=True)
         self.embedder_model = self.llm_model if "embeddings" not in config else self._create_llm(
             config["embeddings"])
 
@@ -32,7 +32,16 @@ def __init__(self, prompt: str, config: dict, source: Optional[str] = None):
         self.final_state = None
         self.execution_info = None
 
-    def _create_llm(self, llm_config: dict):
+    def _set_model_token(self, llm):
+
+        if 'Azure' in str(type(llm)):
+            try:
+                self.model_token = models_tokens["azure"][llm.model_name]
+            except KeyError:
+                raise KeyError("Model not supported")
+
+
+    def _create_llm(self, llm_config: dict, chat=False) -> object:
         """
         Creates an instance of the language model (OpenAI or Gemini) based on configuration.
         """
@@ -42,6 +51,12 @@ def _create_llm(self, llm_config: dict):
         }
         llm_params = {**llm_defaults, **llm_config}
 
+        # If model instance is passed directly instead of the model details
+        if 'model_instance' in llm_params:
+            if chat:
+                self._set_model_token(llm_params['model_instance'])
+            return llm_params['model_instance']
+
         # Instantiate the language model based on the model name
         if "gpt-" in llm_params["model"]:
             try:
@@ -129,3 +144,4 @@ def run(self) -> str:
         Abstract method to execute the graph and return the result.
         """
         pass
+
diff --git a/scrapegraphai/helpers/models_tokens.py b/scrapegraphai/helpers/models_tokens.py
index 6b9ed637..9c8abdef 100644
--- a/scrapegraphai/helpers/models_tokens.py
+++ b/scrapegraphai/helpers/models_tokens.py
@@ -18,7 +18,9 @@
         "gpt-4-32k": 32768,
         "gpt-4-32k-0613": 32768,
     },
-
+    "azure": {
+        "gpt-3.5-turbo": 4096
+    },
     "gemini": {
         "gemini-pro": 128000,
     },
@@ -45,3 +47,4 @@
         "claude3": 200000
     }
 }
+
diff --git a/scrapegraphai/nodes/rag_node.py b/scrapegraphai/nodes/rag_node.py
index d10f50c6..3401ff23 100644
--- a/scrapegraphai/nodes/rag_node.py
+++ b/scrapegraphai/nodes/rag_node.py
@@ -92,6 +92,8 @@ def execute(self, state):
         if isinstance(embedding_model, OpenAI):
             embeddings = OpenAIEmbeddings(
                 api_key=embedding_model.openai_api_key)
+        elif isinstance(embedding_model, AzureOpenAIEmbeddings):
+            embeddings = embedding_model
         elif isinstance(embedding_model, AzureOpenAI):
             embeddings = AzureOpenAIEmbeddings()
         elif isinstance(embedding_model, Ollama):
@@ -133,3 +135,4 @@ def execute(self, state):
         state.update({self.output[0]: compressed_docs})
 
         return state
+