From 75255bb4af61211b48fb576e9206640cd60aeefe Mon Sep 17 00:00:00 2001 From: VinciGit00 Date: Thu, 25 Apr 2024 15:08:42 +0200 Subject: [PATCH] complete get_state function --- scrapegraphai/graphs/abstract_graph.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/scrapegraphai/graphs/abstract_graph.py b/scrapegraphai/graphs/abstract_graph.py index 0433420d..442a809e 100644 --- a/scrapegraphai/graphs/abstract_graph.py +++ b/scrapegraphai/graphs/abstract_graph.py @@ -41,7 +41,7 @@ def _create_llm(self, llm_config: dict): try: self.model_token = models_tokens["openai"][llm_params["model"]] except KeyError: - raise ValueError("Model not supported") + raise KeyError("Model not supported") return OpenAI(llm_params) elif "azure" in llm_params["model"]: @@ -50,14 +50,14 @@ def _create_llm(self, llm_config: dict): try: self.model_token = models_tokens["azure"][llm_params["model"]] except KeyError: - raise ValueError("Model not supported") + raise KeyError("Model not supported") return AzureOpenAI(llm_params) elif "gemini" in llm_params["model"]: try: self.model_token = models_tokens["gemini"][llm_params["model"]] except KeyError: - raise ValueError("Model not supported") + raise KeyError("Model not supported") return Gemini(llm_params) elif "ollama" in llm_params["model"]: @@ -70,19 +70,27 @@ def _create_llm(self, llm_config: dict): try: self.model_token = models_tokens["ollama"][llm_params["model"]] except KeyError: - raise ValueError("Model not supported") + raise KeyError("Model not supported") return Ollama(llm_params) elif "hugging_face" in llm_params["model"]: try: self.model_token = models_tokens["hugging_face"][llm_params["model"]] except KeyError: - raise ValueError("Model not supported") + raise KeyError("Model not supported") return HuggingFace(llm_params) else: raise ValueError( "Model provided by the configuration not supported") + def get_state(self, key=None) -> dict: + """ Obtain the current state """ + if key is not 
None: + return self.final_state[key] + return self.final_state + def get_execution_info(self): """ Returns the execution information of the graph.