Skip to content

Commit

Permalink
feat: Allow end users to pass model instances for llm and embedding m…
Browse files Browse the repository at this point in the history
…odel
  • Loading branch information
shkamboj1 committed May 2, 2024
1 parent 40b2a34 commit b86aac2
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 3 deletions.
63 changes: 63 additions & 0 deletions examples/azure/smart_scraper_azure_openai.py
@@ -0,0 +1,63 @@
"""
Basic example of a scraping pipeline using SmartScraperGraph with Azure OpenAI model instances
"""

import os
from dotenv import load_dotenv
from langchain_openai import AzureChatOpenAI
from langchain_openai import AzureOpenAIEmbeddings
from scrapegraphai.graphs import SmartScraperGraph
from scrapegraphai.utils import prettify_exec_info


## Required environment variables in .env
# AZURE_OPENAI_ENDPOINT
# AZURE_OPENAI_CHAT_DEPLOYMENT_NAME
# MODEL_NAME
# AZURE_OPENAI_API_KEY
# OPENAI_API_TYPE
# AZURE_OPENAI_API_VERSION
# AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT_NAME
# Load the variables listed above from the local .env file into os.environ.
load_dotenv()


# ************************************************
# Initialize the model instances
# ************************************************

# Only the API version and deployment names are passed explicitly here.
# NOTE(review): the endpoint and API key are presumably picked up by the
# langchain clients from the environment loaded above -- confirm.
llm_model_instance = AzureChatOpenAI(
    openai_api_version=os.environ["AZURE_OPENAI_API_VERSION"],
    azure_deployment=os.environ["AZURE_OPENAI_CHAT_DEPLOYMENT_NAME"],
)

embedder_model_instance = AzureOpenAIEmbeddings(
    azure_deployment=os.environ["AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT_NAME"],
    openai_api_version=os.environ["AZURE_OPENAI_API_VERSION"],
)

# ************************************************
# Create the SmartScraperGraph instance and run it
# ************************************************

# Pre-built model instances are handed over under the "model_instance" key
# instead of describing the model by name.
graph_config = {
    "llm": {"model_instance": llm_model_instance},
    "embeddings": {"model_instance": embedder_model_instance},
}

smart_scraper_graph = SmartScraperGraph(
    # Adjacent literals are concatenated at compile time; a plain quoted
    # string cannot span several source lines.
    prompt=(
        "List me all the events, with the following fields: company_name, "
        "event_name, event_start_date, event_start_time, event_end_date, "
        "event_end_time, location, event_mode, event_category, "
        "third_party_redirect, no_of_days, time_in_hours, hosted_or_attending, "
        "refreshments_type, registration_available, registration_link"
    ),
    # also accepts a string with the already downloaded HTML code
    source="https://www.hmhco.com/event",
    config=graph_config,
)

result = smart_scraper_graph.run()
print(result)

# ************************************************
# Get graph execution info
# ************************************************

graph_exec_info = smart_scraper_graph.get_execution_info()
print(prettify_exec_info(graph_exec_info))
20 changes: 18 additions & 2 deletions scrapegraphai/graphs/abstract_graph.py
Expand Up @@ -19,7 +19,7 @@ def __init__(self, prompt: str, config: dict, source: Optional[str] = None):
self.prompt = prompt
self.source = source
self.config = config
self.llm_model = self._create_llm(config["llm"])
self.llm_model = self._create_llm(config["llm"], chat=True)
self.embedder_model = self.llm_model if "embeddings" not in config else self._create_llm(
config["embeddings"])

Expand All @@ -32,7 +32,16 @@ def __init__(self, prompt: str, config: dict, source: Optional[str] = None):
self.final_state = None
self.execution_info = None

def _create_llm(self, llm_config: dict):
def _set_model_token(self, llm):
    """
    Set ``self.model_token`` from a directly supplied Azure model instance.

    The instance is recognised by the substring ``Azure`` in the string form
    of its type. For any other instance nothing is done and
    ``self.model_token`` is left untouched.

    Args:
        llm: Model instance supplied by the caller. For Azure instances it
            must expose a ``model_name`` attribute.

    Raises:
        KeyError: If ``llm.model_name`` has no entry in
            ``models_tokens["azure"]``.
    """
    # Guard clause: non-Azure instances are not handled here.
    if 'Azure' not in str(type(llm)):
        return

    try:
        self.model_token = models_tokens["azure"][llm.model_name]
    except KeyError as exc:
        # Chain the original lookup failure so the missing key stays visible
        # in the traceback.
        raise KeyError("Model not supported") from exc


def _create_llm(self, llm_config: dict, chat=False) -> object:
"""
Creates an instance of the language model (OpenAI or Gemini) based on configuration.
"""
Expand All @@ -42,6 +51,12 @@ def _create_llm(self, llm_config: dict):
}
llm_params = {**llm_defaults, **llm_config}

# If model instance is passed directly instead of the model details
if 'model_instance' in llm_params:
if chat:
self._set_model_token(llm_params['model_instance'])
return llm_params['model_instance']

# Instantiate the language model based on the model name
if "gpt-" in llm_params["model"]:
try:
Expand Down Expand Up @@ -129,3 +144,4 @@ def run(self) -> str:
Abstract method to execute the graph and return the result.
"""
pass

5 changes: 4 additions & 1 deletion scrapegraphai/helpers/models_tokens.py
Expand Up @@ -18,7 +18,9 @@
"gpt-4-32k": 32768,
"gpt-4-32k-0613": 32768,
},

"azure": {
"gpt-3.5-turbo": 4096
},
"gemini": {
"gemini-pro": 128000,
},
Expand All @@ -45,3 +47,4 @@
"claude3": 200000
}
}

3 changes: 3 additions & 0 deletions scrapegraphai/nodes/rag_node.py
Expand Up @@ -92,6 +92,8 @@ def execute(self, state):
if isinstance(embedding_model, OpenAI):
embeddings = OpenAIEmbeddings(
api_key=embedding_model.openai_api_key)
elif isinstance(embedding_model, AzureOpenAIEmbeddings):
embeddings = embedding_model
elif isinstance(embedding_model, AzureOpenAI):
embeddings = AzureOpenAIEmbeddings()
elif isinstance(embedding_model, Ollama):
Expand Down Expand Up @@ -133,3 +135,4 @@ def execute(self, state):

state.update({self.output[0]: compressed_docs})
return state

0 comments on commit b86aac2

Please sign in to comment.