diff --git a/examples/huggingfacehub/smart_scraper_huggingfacehub.py b/examples/huggingfacehub/smart_scraper_huggingfacehub.py new file mode 100644 index 00000000..082ce59c --- /dev/null +++ b/examples/huggingfacehub/smart_scraper_huggingfacehub.py @@ -0,0 +1,63 @@ +""" +Basic example of scraping pipeline using SmartScraper and HuggingFace Hub +""" + +import os +from dotenv import load_dotenv +from scrapegraphai.graphs import SmartScraperGraph +from scrapegraphai.utils import prettify_exec_info +from langchain_community.llms import HuggingFaceEndpoint +from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings + + + + +## required environment variable in .env +#HUGGINGFACEHUB_API_TOKEN +load_dotenv() + +HUGGINGFACEHUB_API_TOKEN = os.getenv('HUGGINGFACEHUB_API_TOKEN') +# ************************************************ +# Initialize the model instances +# ************************************************ + +repo_id = "mistralai/Mistral-7B-Instruct-v0.2" + +llm_model_instance = HuggingFaceEndpoint( + repo_id=repo_id, max_length=128, temperature=0.5, token=HUGGINGFACEHUB_API_TOKEN +) + + + + +embedder_model_instance = HuggingFaceInferenceAPIEmbeddings( + api_key=HUGGINGFACEHUB_API_TOKEN, model_name="sentence-transformers/all-MiniLM-l6-v2" +) + +# ************************************************ +# Create the SmartScraperGraph instance and run it +# ************************************************ +graph_config = { + "llm": {"model_instance": llm_model_instance}, + "embeddings": {"model_instance": embedder_model_instance} +} + +smart_scraper_graph = SmartScraperGraph( + prompt="List me all the events, with the following fields: company_name, event_name, event_start_date, event_start_time, event_end_date, event_end_time, location, event_mode, event_category, third_party_redirect, no_of_days, time_in_hours, hosted_or_attending, refreshments_type, registration_available, registration_link", + # also accepts a string with the already downloaded HTML 
code + source="https://www.hmhco.com/event", + config=graph_config +) + +result = smart_scraper_graph.run() +print(result) + +# ************************************************ +# Get graph execution info +# ************************************************ + +graph_exec_info = smart_scraper_graph.get_execution_info() +print(prettify_exec_info(graph_exec_info)) + + diff --git a/scrapegraphai/graphs/abstract_graph.py b/scrapegraphai/graphs/abstract_graph.py index 91e7fcf6..83b5b712 100644 --- a/scrapegraphai/graphs/abstract_graph.py +++ b/scrapegraphai/graphs/abstract_graph.py @@ -69,6 +69,13 @@ def _set_model_token(self, llm): self.model_token = models_tokens["azure"][llm.model_name] except KeyError: raise KeyError("Model not supported") + + elif 'HuggingFaceEndpoint' in str(type(llm)): + if 'mistral' in llm.repo_id: + try: + self.model_token = models_tokens['mistral'][llm.repo_id] + except KeyError: + raise KeyError("Model not supported") def _create_llm(self, llm_config: dict, chat=False) -> object: @@ -181,7 +188,6 @@ def _create_default_embedder(self) -> object: Raises: ValueError: If the model is not supported. """ - if isinstance(self.llm_model, OpenAI): return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key) elif isinstance(self.llm_model, AzureOpenAIEmbeddings): @@ -216,6 +222,9 @@ def _create_embedder(self, embedder_config: dict) -> object: Raises: KeyError: If the model is not supported. 
""" + + if 'model_instance' in embedder_config: + return embedder_config['model_instance'] # Instantiate the embedding model based on the model name if "openai" in embedder_config["model"]: diff --git a/scrapegraphai/helpers/models_tokens.py b/scrapegraphai/helpers/models_tokens.py index a9bab3fc..5bc9a7f8 100644 --- a/scrapegraphai/helpers/models_tokens.py +++ b/scrapegraphai/helpers/models_tokens.py @@ -65,5 +65,8 @@ "mistral.mistral-large-2402-v1:0": 32768, "cohere.embed-english-v3": 512, "cohere.embed-multilingual-v3": 512 + }, + "mistral": { + "mistralai/Mistral-7B-Instruct-v0.2": 32000 } } diff --git a/scrapegraphai/nodes/rag_node.py b/scrapegraphai/nodes/rag_node.py index a5c4e58a..b883845a 100644 --- a/scrapegraphai/nodes/rag_node.py +++ b/scrapegraphai/nodes/rag_node.py @@ -82,6 +82,8 @@ def execute(self, state: dict) -> dict: if self.verbose: print("--- (updated chunks metadata) ---") + # check if embedder_model is provided, if not use llm_model + self.embedder_model = self.embedder_model if self.embedder_model else self.llm_model embeddings = self.embedder_model retriever = FAISS.from_documents(