diff --git a/scrapegraphai/graphs/abstract_graph.py b/scrapegraphai/graphs/abstract_graph.py index aff7289c..a263390d 100644 --- a/scrapegraphai/graphs/abstract_graph.py +++ b/scrapegraphai/graphs/abstract_graph.py @@ -1,16 +1,14 @@ """ AbstractGraph Module """ - from abc import ABC, abstractmethod from typing import Optional - -from langchain_aws.embeddings.bedrock import BedrockEmbeddings -from langchain_community.embeddings import HuggingFaceHubEmbeddings, OllamaEmbeddings from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings - +from langchain_community.embeddings import HuggingFaceHubEmbeddings, OllamaEmbeddings from ..helpers import models_tokens from ..models import AzureOpenAI, Bedrock, Gemini, Groq, HuggingFace, Ollama, OpenAI, Claude +from langchain_aws.embeddings.bedrock import BedrockEmbeddings +from langchain_google_genai import GoogleGenerativeAIEmbeddings class AbstractGraph(ABC): @@ -69,7 +67,7 @@ def _set_model_token(self, llm): self.model_token = models_tokens["azure"][llm.model_name] except KeyError: raise KeyError("Model not supported") - + elif 'HuggingFaceEndpoint' in str(type(llm)): if 'mistral' in llm.repo_id: try: @@ -229,14 +227,11 @@ def _create_embedder(self, embedder_config: dict) -> object: if 'model_instance' in embedder_config: return embedder_config['model_instance'] - # Instantiate the embedding model based on the model name if "openai" in embedder_config["model"]: return OpenAIEmbeddings(api_key=embedder_config["api_key"]) - elif "azure" in embedder_config["model"]: return AzureOpenAIEmbeddings() - elif "ollama" in embedder_config["model"]: embedder_config["model"] = embedder_config["model"].split("/")[-1] try: @@ -244,14 +239,18 @@ def _create_embedder(self, embedder_config: dict) -> object: except KeyError as exc: raise KeyError("Model not supported") from exc return OllamaEmbeddings(**embedder_config) - elif "hugging_face" in embedder_config["model"]: try: models_tokens["hugging_face"][embedder_config["model"]] 
except KeyError as exc: raise KeyError("Model not supported")from exc return HuggingFaceHubEmbeddings(model=embedder_config["model"]) - + elif "gemini" in embedder_config["model"]: + try: + models_tokens["gemini"][embedder_config["model"]] + except KeyError as exc: + raise KeyError("Model not supported") from exc + return GoogleGenerativeAIEmbeddings(model=embedder_config["model"]) elif "bedrock" in embedder_config["model"]: embedder_config["model"] = embedder_config["model"].split("/")[-1] try: