diff --git a/examples/groq/smart_scraper_groq_openai.py b/examples/groq/smart_scraper_groq_openai.py
index 19f86145..47c42303 100644
--- a/examples/groq/smart_scraper_groq_openai.py
+++ b/examples/groq/smart_scraper_groq_openai.py
@@ -25,7 +25,7 @@
     },
     "embeddings": {
         "api_key": openai_key,
-        "model": "gpt-3.5-turbo",
+        "model": "openai",
     },
     "headless": False
 }
diff --git a/scrapegraphai/graphs/abstract_graph.py b/scrapegraphai/graphs/abstract_graph.py
index b8a9efe9..91e7fcf6 100644
--- a/scrapegraphai/graphs/abstract_graph.py
+++ b/scrapegraphai/graphs/abstract_graph.py
@@ -5,8 +5,12 @@
 from abc import ABC, abstractmethod
 from typing import Optional
 
-from ..models import OpenAI, Gemini, Ollama, AzureOpenAI, HuggingFace, Groq, Bedrock
+from langchain_aws.embeddings.bedrock import BedrockEmbeddings
+from langchain_community.embeddings import HuggingFaceHubEmbeddings, OllamaEmbeddings
+from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
+
 from ..helpers import models_tokens
+from ..models import AzureOpenAI, Bedrock, Gemini, Groq, HuggingFace, Ollama, OpenAI
 
 
 class AbstractGraph(ABC):
@@ -43,7 +47,8 @@ def __init__(self, prompt: str, config: dict, source: Optional[str] = None):
         self.source = source
         self.config = config
         self.llm_model = self._create_llm(config["llm"], chat=True)
-        self.embedder_model = self.llm_model if "embeddings" not in config else self._create_llm(
+        self.embedder_model = self._create_default_embedder(
+        ) if "embeddings" not in config else self._create_embedder(
             config["embeddings"])
 
         # Set common configuration parameters
@@ -165,6 +170,85 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
         else:
             raise ValueError(
                 "Model provided by the configuration not supported")
+
+    def _create_default_embedder(self) -> object:
+        """
+        Create an embedding model instance based on the chosen llm model.
+
+        Returns:
+            object: An instance of the embedding model client.
+
+        Raises:
+            ValueError: If the model is not supported.
+        """
+
+        if isinstance(self.llm_model, OpenAI):
+            return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key)
+        elif isinstance(self.llm_model, AzureOpenAIEmbeddings):
+            return self.llm_model
+        elif isinstance(self.llm_model, AzureOpenAI):
+            return AzureOpenAIEmbeddings()
+        elif isinstance(self.llm_model, Ollama):
+            # unwrap the kwargs from the model, which is a dict
+            params = self.llm_model._lc_kwargs
+            # remove streaming and temperature
+            params.pop("streaming", None)
+            params.pop("temperature", None)
+
+            return OllamaEmbeddings(**params)
+        elif isinstance(self.llm_model, HuggingFace):
+            return HuggingFaceHubEmbeddings(model=self.llm_model.model)
+        elif isinstance(self.llm_model, Bedrock):
+            return BedrockEmbeddings(client=None, model_id=self.llm_model.model_id)
+        else:
+            raise ValueError("Embedding Model missing or not supported")
+
+    def _create_embedder(self, embedder_config: dict) -> object:
+        """
+        Create an embedding model instance based on the configuration provided.
+
+        Args:
+            embedder_config (dict): Configuration parameters for the embedding model.
+
+        Returns:
+            object: An instance of the embedding model client.
+
+        Raises:
+            KeyError: If the model is not supported.
+ """ + + # Instantiate the embedding model based on the model name + if "openai" in embedder_config["model"]: + return OpenAIEmbeddings(api_key=embedder_config["api_key"]) + + elif "azure" in embedder_config["model"]: + return AzureOpenAIEmbeddings() + + elif "ollama" in embedder_config["model"]: + embedder_config["model"] = embedder_config["model"].split("/")[-1] + try: + models_tokens["ollama"][embedder_config["model"]] + except KeyError: + raise KeyError("Model not supported") + return OllamaEmbeddings(**embedder_config) + + elif "hugging_face" in embedder_config["model"]: + try: + models_tokens["hugging_face"][embedder_config["model"]] + except KeyError: + raise KeyError("Model not supported") + return HuggingFaceHubEmbeddings(model=embedder_config["model"]) + + elif "bedrock" in embedder_config["model"]: + embedder_config["model"] = embedder_config["model"].split("/")[-1] + try: + models_tokens["bedrock"][embedder_config["model"]] + except KeyError: + raise KeyError("Model not supported") + return BedrockEmbeddings(client=None, model_id=embedder_config["model"]) + else: + raise ValueError( + "Model provided by the configuration not supported") def get_state(self, key=None) -> dict: """"" diff --git a/scrapegraphai/nodes/rag_node.py b/scrapegraphai/nodes/rag_node.py index 92e7011f..4108a56c 100644 --- a/scrapegraphai/nodes/rag_node.py +++ b/scrapegraphai/nodes/rag_node.py @@ -87,31 +87,7 @@ def execute(self, state: dict) -> dict: if self.verbose: print("--- (updated chunks metadata) ---") - # check if embedder_model is provided, if not use llm_model - embedding_model = self.embedder_model if self.embedder_model else self.llm_model - - if isinstance(embedding_model, OpenAI): - embeddings = OpenAIEmbeddings( - api_key=embedding_model.openai_api_key) - elif isinstance(embedding_model, AzureOpenAIEmbeddings): - embeddings = embedding_model - elif isinstance(embedding_model, AzureOpenAI): - embeddings = AzureOpenAIEmbeddings() - elif isinstance(embedding_model, Ollama): - # unwrap the kwargs from the model whihc is a dict - params = embedding_model._lc_kwargs - # remove streaming and temperature - params.pop("streaming", None) - params.pop("temperature", None) - - embeddings = OllamaEmbeddings(**params) - elif isinstance(embedding_model, HuggingFace): - embeddings = HuggingFaceHubEmbeddings(model=embedding_model.model) - elif isinstance(embedding_model, Bedrock): - embeddings = BedrockEmbeddings( - client=None, model_id=embedding_model.model_id) - else: - raise ValueError("Embedding Model missing or not supported") + embeddings = self.embedder_model retriever = FAISS.from_documents( chunked_docs, embeddings).as_retriever()