diff --git a/libs/arangodb/.coverage b/libs/arangodb/.coverage
index 35611f6..6be0db1 100644
Binary files a/libs/arangodb/.coverage and b/libs/arangodb/.coverage differ
diff --git a/libs/arangodb/langchain_arangodb/chains/graph_qa/arangodb.py b/libs/arangodb/langchain_arangodb/chains/graph_qa/arangodb.py
index f2050af..e1f0d73 100644
--- a/libs/arangodb/langchain_arangodb/chains/graph_qa/arangodb.py
+++ b/libs/arangodb/langchain_arangodb/chains/graph_qa/arangodb.py
@@ -8,6 +8,7 @@
 from arango import AQLQueryExecuteError, AQLQueryExplainError
 from langchain.chains.base import Chain
 from langchain_core.callbacks import CallbackManagerForChainRun
+from langchain_core.embeddings import Embeddings
 from langchain_core.language_models import BaseLanguageModel
 from langchain_core.messages import AIMessage
 from langchain_core.prompts import BasePromptTemplate
@@ -19,7 +20,7 @@
     AQL_GENERATION_PROMPT,
     AQL_QA_PROMPT,
 )
-from langchain_arangodb.graphs.graph_store import GraphStore
+from langchain_arangodb.graphs.arangodb_graph import ArangoGraph
 
 AQL_WRITE_OPERATIONS: List[str] = [
     "INSERT",
@@ -45,7 +46,9 @@ class ArangoGraphQAChain(Chain):
     See https://python.langchain.com/docs/security for more information.
     """
 
-    graph: GraphStore = Field(exclude=True)
+    graph: ArangoGraph = Field(exclude=True)
+    embedding: Optional[Embeddings] = Field(default=None, exclude=True)
+    query_cache_collection_name: str = Field(default="Queries")
     aql_generation_chain: Runnable[Dict[str, Any], Any]
     aql_fix_chain: Runnable[Dict[str, Any], Any]
     qa_chain: Runnable[Dict[str, Any], Any]
@@ -102,6 +105,8 @@ def __init__(self, **kwargs: Any) -> None:
             "necessary precautions. "
             "See https://python.langchain.com/docs/security for more information."
         )
+        self._last_user_input: Optional[str] = None
+        self._last_aql_query: Optional[str] = None
 
     @property
     def input_keys(self) -> List[str]:
@@ -132,6 +137,11 @@ def from_llm(
         :param llm: The language model to use.
         :type llm: BaseLanguageModel
+        :param embedding: The embedding model to use.
+        :type embedding: Embeddings
+        :param query_cache_collection_name: The name of the collection
+            to use for the query cache.
+        :type query_cache_collection_name: str
         :param qa_prompt: The prompt to use for the QA chain.
         :type qa_prompt: BasePromptTemplate
         :param aql_generation_prompt: The prompt to use for the AQL generation chain.
@@ -162,6 +172,174 @@ def from_llm(
             **kwargs,
         )
 
+    def _check_and_insert_query(self, text: str, aql: str) -> str:
+        """
+        Check if a query is already in the cache and insert it if it's not.
+
+        :param text: The text of the query to check.
+        :type text: str
+        :param aql: The AQL query to check.
+        :type aql: str
+        :return: A message indicating the result of the operation.
+        """
+        text = text.strip().lower()
+        text_hash = self.graph._hash(text)
+        collection = self.graph.db.collection(self.query_cache_collection_name)
+
+        if collection.has(text_hash):
+            return f"This query is already in the cache: {text}"
+
+        if self.embedding is None:
+            raise ValueError("Cannot cache queries without an embedding model.")
+
+        query_embedding = self.embedding.embed_query(text)
+        collection.insert(
+            {
+                "_key": text_hash,
+                "text": text,
+                "embedding": query_embedding,
+                "aql": aql,
+            }
+        )
+
+        return f"Cached: {text}"
+
+    def cache_query(self, text: Optional[str] = None, aql: Optional[str] = None) -> str:
+        """
+        Cache a query generated by the LLM only if it's not already stored.
+
+        :param text: The text of the query to cache.
+        :param aql: The AQL query to cache.
+        :return: A message indicating the result of the operation.
+        """
+        if self.embedding is None:
+            raise ValueError("Cannot cache queries without an embedding model.")
+
+        if not self.graph.db.has_collection(self.query_cache_collection_name):
+            m = f"Collection {self.query_cache_collection_name} does not exist"  # noqa: E501
+            raise ValueError(m)
+
+        if not text and aql:
+            raise ValueError("Text is required to cache a query")
+
+        if text and not aql:
+            raise ValueError("AQL is required to cache a query")
+
+        if not text and not aql:
+            if self._last_user_input is None or self._last_aql_query is None:
+                m = "No previous query to cache. Please provide **text** and **aql**."
+                raise ValueError(m)
+
+            # Fallback: cache the most recent query
+            return self._check_and_insert_query(
+                self._last_user_input,
+                self._last_aql_query,
+            )
+
+        if not isinstance(text, str) or not isinstance(aql, str):
+            raise ValueError("Text and AQL must be strings")
+
+        return self._check_and_insert_query(text, aql)
+
+    def clear_query_cache(self, text: Optional[str] = None) -> str:
+        """
+        Clear the query cache.
+
+        :param text: The text of the query to delete from the cache.
+            If None, the entire cache is cleared.
+        :type text: str
+        :return: A message indicating the result of the operation.
+        """
+
+        if not self.graph.db.has_collection(self.query_cache_collection_name):
+            m = f"Collection {self.query_cache_collection_name} does not exist"
+            raise ValueError(m)
+
+        collection = self.graph.db.collection(self.query_cache_collection_name)
+
+        if text is None:
+            collection.truncate()
+            return "Cleared all queries from the cache"
+
+        text = text.strip().lower()
+        text_hash = self.graph._hash(text)
+
+        if collection.has(text_hash):
+            collection.delete(text_hash)
+            return f"Removed: {text}"
+
+        return f"Not found: {text}"
+
+    def _get_cached_query(
+        self, user_input: str, query_cache_similarity_threshold: float
+    ) -> Optional[Tuple[str, str]]:
+        """Get the cached query for the user input. Only used if an embedding
+        model is provided and **use_query_cache** is True.
+
+        :param user_input: The user input to search for in the cache.
+        :type user_input: str
+        :param query_cache_similarity_threshold: The minimum cosine similarity
+            score required to reuse a cached query via vector search.
+        :type query_cache_similarity_threshold: float
+
+        :return: The cached AQL query and its similarity score, if found.
+        :rtype: Optional[Tuple[str, str]]
+        """
+        if self.embedding is None:
+            raise ValueError("Cannot enable query cache without passing embedding")
+
+        if self.graph.db.collection(self.query_cache_collection_name).count() == 0:
+            return None
+
+        user_input = user_input.strip().lower()
+
+        # 1. Exact Search
+
+        query = f"""
+            FOR q IN {self.query_cache_collection_name}
+                FILTER q.text == @user_input
+                LIMIT 1
+                RETURN q.aql
+        """
+
+        cursor = self.graph.db.aql.execute(
+            query,
+            bind_vars={"user_input": user_input},
+        )
+
+        result = list(cursor)  # type: ignore
+
+        if result:
+            return result[0], "1.0"
+
+        # 2. Vector Search
+
+        embedding = self.embedding.embed_query(user_input)
+        query = """
+            FOR q IN @@col
+                LET score = COSINE_SIMILARITY(q.embedding, @embedding)
+                SORT score DESC
+                LIMIT 1
+                FILTER score > @score_threshold
+                RETURN {aql: q.aql, score: score}
+        """
+
+        result = list(
+            self.graph.db.aql.execute(
+                query,
+                bind_vars={
+                    "@col": self.query_cache_collection_name,
+                    "embedding": embedding,  # type: ignore
+                    "score_threshold": query_cache_similarity_threshold,  # type: ignore
+                },
+            )
+        )
+
+        if result:
+            return result[0]["aql"], str(round(result[0]["score"], 2))
+
+        return None
+
     def _call(
         self,
         inputs: Dict[str, Any],
@@ -211,20 +385,46 @@
         """
         _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
         callbacks = _run_manager.get_child()
-        user_input = inputs[self.input_key]
+        user_input = inputs[self.input_key].strip().lower()
+        use_query_cache = inputs.get("use_query_cache", False)
+        query_cache_similarity_threshold = inputs.get(
+            "query_cache_similarity_threshold", 0.80
+        )
+
+        if use_query_cache and self.embedding is None:
+            raise ValueError("Cannot enable query cache without passing embedding")
 
         ######################
-        # Generate AQL Query #
+        # Check Query Cache  #
         ######################
 
-        aql_generation_output = self.aql_generation_chain.invoke(
-            {
-                "adb_schema": self.graph.schema_yaml,
-                "aql_examples": self.aql_examples,
-                "user_input": user_input,
-            },
-            callbacks=callbacks,
-        )
+        cached_query, score = None, None
+        if use_query_cache:
+            if self.embedding is None:
+                m = "Embedding must be provided when using query cache"
+                raise ValueError(m)
+
+            if not self.graph.db.has_collection(self.query_cache_collection_name):
+                self.graph.db.create_collection(self.query_cache_collection_name)
+
+            cache_result = self._get_cached_query(
+                user_input, query_cache_similarity_threshold
+            )
+
+            if cache_result is not None:
+                cached_query, score = cache_result
+
+        if cached_query:
+            aql_generation_output = f"```aql{cached_query}```"
+        else:
+            aql_generation_output = self.aql_generation_chain.invoke(
+                {
+                    "adb_schema": self.graph.schema_yaml,
+                    "aql_examples": self.aql_examples,
+                    "user_input": user_input,
+                },
+                callbacks=callbacks,
+            )
 
         aql_query = ""
         aql_error = ""
@@ -283,9 +483,14 @@
             """
             raise ValueError(error_msg)
 
-        _run_manager.on_text(
-            f"AQL Query ({aql_generation_attempt}):", verbose=self.verbose
-        )
+        query_message = f"AQL Query ({aql_generation_attempt})\n"
+        if cached_query:
+            score_string = score if score is not None else "1.0"
+            query_message = (
+                f"AQL Query (used cached query, score: {score_string})\n"  # noqa: E501
+            )
+
+        _run_manager.on_text(query_message, verbose=self.verbose)
         _run_manager.on_text(
             aql_query, color="green", end="\n", verbose=self.verbose
         )
@@ -300,7 +505,6 @@
                 "list_limit": self.output_list_limit,
                 "string_limit": self.output_string_limit,
             }
-
             aql_result = aql_execution_func(aql_query, params)
         except (AQLQueryExecuteError, AQLQueryExplainError) as e:
             aql_error = str(e.error_message)
@@ -368,6 +572,9 @@
         if self.return_aql_result:
             results["aql_result"] = aql_result
 
+        self._last_user_input = user_input
+        self._last_aql_query = aql_query
+
         return results
 
     def _is_read_only_query(self, aql_query: str) -> Tuple[bool, Optional[str]]:
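Reviewer note: to make the intended flow of the chain changes above concrete, here is a minimal end-to-end sketch. The connection details and the `langchain_openai` models are illustrative assumptions, not part of this diff; any `Embeddings` implementation can back the new `embedding` field. The invoke keys (`use_query_cache`, `query_cache_similarity_threshold`) and the `cache_query` / `clear_query_cache` helpers are the ones introduced here.

```python
from arango import ArangoClient
from langchain_openai import ChatOpenAI, OpenAIEmbeddings  # assumed models

from langchain_arangodb.chains.graph_qa.arangodb import ArangoGraphQAChain
from langchain_arangodb.graphs.arangodb_graph import ArangoGraph

# Assumed local ArangoDB credentials, for illustration only.
db = ArangoClient().db("_system", username="root", password="passwd")
graph = ArangoGraph(db)

chain = ArangoGraphQAChain.from_llm(
    llm=ChatOpenAI(model="gpt-4o"),
    graph=graph,
    embedding=OpenAIEmbeddings(),  # required for every cache operation
    query_cache_collection_name="Queries",  # the default, shown for clarity
    allow_dangerous_requests=True,
)

# Cache miss: the LLM generates AQL exactly as before.
result = chain.invoke(
    {
        "query": "List all movies",
        "use_query_cache": True,
        "query_cache_similarity_threshold": 0.80,  # the default
    }
)

# Persist the last (user input, AQL) pair; texts are lowercased on insert.
chain.cache_query()

# A close paraphrase can now short-circuit generation via COSINE_SIMILARITY.
chain.invoke({"query": "Show me every movie", "use_query_cache": True})

# Evict one entry, or truncate the whole cache collection.
chain.clear_query_cache(text="List all movies")
chain.clear_query_cache()
```

One design consequence worth flagging: `use_query_cache` is a per-invoke input rather than a constructor flag, so callers opt in per question, and the cache collection is created lazily on the first cached invoke.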
diff --git a/libs/arangodb/langchain_arangodb/graphs/arangodb_graph.py b/libs/arangodb/langchain_arangodb/graphs/arangodb_graph.py
index 797ade1..0eba47c 100644
--- a/libs/arangodb/langchain_arangodb/graphs/arangodb_graph.py
+++ b/libs/arangodb/langchain_arangodb/graphs/arangodb_graph.py
@@ -424,7 +424,7 @@ def query(
         top_k = params.pop("top_k", None)
         list_limit = params.pop("list_limit", 32)
         string_limit = params.pop("string_limit", 256)
-        cursor = self.__db.aql.execute(query, **params)
+        cursor = self.db.aql.execute(query, **params)
 
         results = []
 
@@ -454,7 +454,7 @@ def explain(self, query: str, params: dict = {}) -> List[Dict[str, Any]]:
         :raises ArangoServerError: If the ArangoDB server cannot be reached.
         :raises ArangoCollectionError: If the collection cannot be created.
         """
-        return self.__db.aql.explain(query)  # type: ignore
+        return self.db.aql.explain(query)  # type: ignore
 
     def add_graph_documents(
         self,
diff --git a/libs/arangodb/tests/integration_tests/chains/test_graph_database.py b/libs/arangodb/tests/integration_tests/chains/test_graph_database.py
index 9805780..6eb0f58 100644
--- a/libs/arangodb/tests/integration_tests/chains/test_graph_database.py
+++ b/libs/arangodb/tests/integration_tests/chains/test_graph_database.py
@@ -7,6 +7,7 @@
 from arango.database import StandardDatabase
 from langchain_core.language_models import BaseLanguageModel
 from langchain_core.messages import AIMessage
+from langchain_core.runnables import RunnableLambda
 
 from langchain_arangodb.chains.graph_qa.arangodb import ArangoGraphQAChain
 from langchain_arangodb.graphs.arangodb_graph import ArangoGraph
@@ -888,3 +889,180 @@ def test_init_succeeds_if_dangerous_requests_allowed() -> None:
             "ValueError was raised unexpectedly when \
             allow_dangerous_requests=True"
         )
+
+
+@pytest.mark.usefixtures("clear_arangodb_database")
+def test_query_cache(db: StandardDatabase) -> None:
+    """Test the query cache end to end:
+    1. Exact search
+    2. Vector search
+    3. AQL generation and store new query
+    4. Query cache disabled
+    5. Query cache without embedding
+    6. _check_and_insert_query
+    7. cache_query
+    8. clear_query_cache
+    9. _get_cached_query
+    """
+    graph = ArangoGraph(db)
+    graph.db.create_collection("Movies")
+    graph.db.create_collection("Queries")
+
+    queries = [
+        {
+            "text": "list all movies",  # stored lowercased, as the chain caches it
+            "aql": "FOR m IN Movies RETURN m",
+            "embedding": [0.123, 0.456, 0.789, 0.321, 0.654],
+        }
+    ]
+    graph.db.collection("Queries").insert_many(queries)
+
+    movies = [
+        {"title": "The Matrix"},
+        {"title": "Inception"},
+        {"title": "Interstellar"},
+    ]
+    graph.db.collection("Movies").insert_many(movies)
+    graph.refresh_schema()
+
+    dummy_llm = RunnableLambda(lambda prompt: "```FOR m IN Movies LIMIT 1 RETURN m```")
+
+    chain = ArangoGraphQAChain.from_llm(
+        llm=dummy_llm,  # type: ignore
+        graph=graph,
+        verbose=True,
+        allow_dangerous_requests=True,
+        return_aql_result=True,
+        query_cache_collection_name="Queries",
+    )
+
+    chain.embedding = type(
+        "FakeEmbedding", (), {"embed_query": staticmethod(lambda text: [0.123] * 5)}
+    )()
+
+    # 1. Test with exact search
+    result1 = chain.invoke({"query": "List all movies", "use_query_cache": True})
+    assert [m["title"] for m in result1["aql_result"]] == [
+        "The Matrix",
+        "Inception",
+        "Interstellar",
+    ]
+
+    # 2. Test with vector search
+    result2 = chain.invoke(
+        {
+            "query": "Show me all movies",
+            "use_query_cache": True,
+            "query_cache_similarity_threshold": 0.80,
+        }
+    )
+    assert [m["title"] for m in result2["aql_result"]] == [
+        "The Matrix",
+        "Inception",
+        "Interstellar",
+    ]
+
+    # 3. Test with AQL generation and store new query
+    chain.embedding = type(
+        "FakeEmbedding", (), {"embed_query": staticmethod(lambda text: [1, 0, 0, 0, 0])}
+    )()
+    result3 = chain.invoke(
+        {"query": "What is the name of the first movie?", "use_query_cache": True}
+    )
+    chain.cache_query()
+    assert result3["aql_result"][0]["title"] == "The Matrix"
+    assert graph.db.collection("Queries").count() == 2
+
+    # 4. Test with query cache disabled
+    result4 = chain.invoke({"query": "What is the name of the first movie?"})
+    assert result4["aql_result"][0]["title"] == "The Matrix"
+
+    # 5. Test query cache without embedding
+    chain.embedding = None
+    with pytest.raises(
+        ValueError, match="Cannot enable query cache without passing embedding"
+    ):
+        chain.invoke({"query": "List all movies", "use_query_cache": True})
+
+    # 6. Test _check_and_insert_query
+    chain.embedding = type(
+        "FakeEmbedding", (), {"embed_query": staticmethod(lambda text: [0.111] * 5)}
+    )()
+
+    # Insert a new query
+    msg1 = chain._check_and_insert_query(
+        "Find sci-fi movies", "FOR m IN Movies FILTER m.genre == 'sci-fi' RETURN m"
+    )
+    assert msg1.startswith("Cached:")
+
+    # Re-insert the same query -> should detect duplicate
+    msg2 = chain._check_and_insert_query(
+        "Find sci-fi movies", "FOR m IN Movies FILTER m.genre == 'sci-fi' RETURN m"
+    )
+    assert msg2.startswith("This query is already in the cache")
+
+    # 7. Test cache_query
+
+    # Fallback to _last_user_input/_last_aql_query
+    chain._last_user_input = "List animated movies"
+    chain._last_aql_query = "FOR m IN Movies FILTER m.genre == 'animation' RETURN m"
+    msg = chain.cache_query()
+    assert msg == "Cached: list animated movies"
+
+    # Missing text or AQL should raise
+    with pytest.raises(ValueError, match="Text is required to cache a query"):
+        chain.cache_query(aql="FOR m IN Movies RETURN m")
+
+    with pytest.raises(ValueError, match="AQL is required to cache a query"):
+        chain.cache_query(text="List movies")
+
+    # 8. Test clear_query_cache
+    chain.embedding = type(
+        "FakeEmbedding", (), {"embed_query": staticmethod(lambda text: [0.111] * 5)}
+    )()
+
+    # Add a test query
+    chain.cache_query(text="Temp query", aql="FOR m IN Movies RETURN m")
+    assert graph.db.collection("Queries").has(graph._hash("temp query"))
+
+    # Delete specific query
+    msg = chain.clear_query_cache(text="Temp query")
+    assert msg == "Removed: temp query"
+    assert not graph.db.collection("Queries").has(graph._hash("temp query"))
+
+    # Clear all
+    msg = chain.clear_query_cache()
+    assert msg == "Cleared all queries from the cache"
+    assert graph.db.collection("Queries").count() == 0
+
+    # 9. Test _get_cached_query
+    # Insert two queries
+    chain.embedding = type(
+        "FakeEmbedding", (), {"embed_query": staticmethod(lambda text: [0.123] * 5)}
+    )()
+    chain.cache_query(text="List all movies", aql="FOR m IN Movies RETURN m")
+    chain.embedding = type(
+        "FakeEmbedding", (), {"embed_query": staticmethod(lambda text: [0.124] * 5)}
+    )()
+    chain.cache_query(text="Show all movies", aql="FOR m IN Movies RETURN m")
+
+    # Exact match
+    chain.embedding = type(
+        "FakeEmbedding", (), {"embed_query": staticmethod(lambda text: [0.123] * 5)}
+    )()
+    query = chain._get_cached_query("list all movies", 0.8)
+    assert query[0].startswith("FOR m IN Movies RETURN")  # type: ignore
+
+    # Vector match (input not stored verbatim, so exact search misses)
+    chain.embedding = type(
+        "FakeEmbedding", (), {"embed_query": staticmethod(lambda text: [0.124] * 5)}
+    )()
+    query = chain._get_cached_query("display all movies", 0.8)
+    assert query[0].startswith("FOR m IN Movies RETURN")  # type: ignore
+
+    # No match
+    chain.embedding = type(
+        "FakeEmbedding", (), {"embed_query": staticmethod(lambda text: [0.0] * 5)}
+    )()
+    query = chain._get_cached_query("gibberish", 0.99)
+    assert query is None
diff --git a/libs/arangodb/tests/unit_tests/chains/test_graph_qa.py b/libs/arangodb/tests/unit_tests/chains/test_graph_qa.py
index f7aff6a..b6312af 100644
--- a/libs/arangodb/tests/unit_tests/chains/test_graph_qa.py
+++ b/libs/arangodb/tests/unit_tests/chains/test_graph_qa.py
@@ -1,5 +1,6 @@
 """Unit tests for ArangoGraphQAChain."""
 
+import math
 from typing import Any, Dict, List
 from unittest.mock import MagicMock, Mock
 
@@ -7,15 +8,15 @@
 from arango import AQLQueryExecuteError
 from langchain_core.callbacks import CallbackManagerForChainRun
 from langchain_core.messages import AIMessage
-from langchain_core.runnables import Runnable
+from langchain_core.runnables import Runnable, RunnableLambda
 
 from langchain_arangodb.chains.graph_qa.arangodb import ArangoGraphQAChain
-from langchain_arangodb.graphs.graph_store import GraphStore
+from langchain_arangodb.graphs.arangodb_graph import ArangoGraph
 from tests.llms.fake_llm import FakeLLM
 
 
-class FakeGraphStore(GraphStore):
-    """A fake GraphStore implementation for testing purposes."""
+class FakeGraphStore(ArangoGraph):
+    """A fake ArangoGraph implementation for testing purposes."""
 
     def __init__(self) -> None:
         self._schema_yaml = "node_props:\n Movie:\n - property: title\n type: STRING"
@@ -27,10 +28,24 @@ def __init__(self) -> None:
         self.refreshed = False
         self.graph_documents_added = []  # type: ignore
 
+        # Mock the database interface
+        self.__db = Mock()
+        self.__db.collection = Mock()
+        mock_queries_collection = Mock()
+        mock_queries_collection.find = Mock(return_value=[])
+        mock_queries_collection.insert = Mock()
+        self.__db.collection.return_value = mock_queries_collection
+        self.__db.aql = Mock()
+        self.__db.aql.execute = Mock(return_value=[])
+
     @property
     def schema_yaml(self) -> str:
         return self._schema_yaml
 
+    @property
+    def db(self) -> Mock:  # type: ignore
+        return self.__db  # type: ignore
+
     @property
     def schema_json(self) -> str:
         return self._schema_json
@@ -43,7 +58,7 @@ def explain(self, query: str, params: dict = {}) -> List[Dict[str, Any]]:
         self.explains_run.append((query, params))
         return [{"plan": "This is a fake AQL query plan."}]
 
-    def refresh_schema(self) -> None:
+    def refresh_schema(self) -> None:  # type: ignore
         self.refreshed = True
 
     def add_graph_documents(  # type: ignore
@@ -67,6 +82,26 @@ def fake_llm(self) -> FakeLLM:
         """Create a fake LLM."""
         return FakeLLM()
 
+    @pytest.fixture
+    def mock_embedding(self) -> Mock:
+ """Create a mock embedding model.""" + mock_emb = Mock() + mock_emb.embed_query = Mock( + return_value=[0.1, 0.2, 0.3] + ) # Simple mock embedding vector + return mock_emb + + @pytest.fixture + def mock_db_with_queries(self) -> Mock: + """Create a mock database with a Queries collection.""" + mock_db = Mock() + mock_collection = Mock() + mock_collection.find = Mock(return_value=[]) # Default to empty results + mock_db.collection = Mock(return_value=mock_collection) + mock_db.aql = Mock() + mock_db.aql.execute = Mock(return_value=[]) # Default to empty results + return mock_db + @pytest.fixture def mock_chains(self) -> Dict[str, Runnable]: """Create mock chains that correctly implement the Runnable abstract class.""" @@ -528,3 +563,84 @@ def test_call_with_callback_manager( assert "result" in result assert mock_run_manager.get_child.called + + def test_query_cache(self, fake_graph_store: FakeGraphStore) -> None: + """Test query cache in 4 situations: + 1. Exact search + 2. Vector search + 3. AQL generation and store new query + 4. Query cache disabled + """ + + def cosine_similarity(v1, v2): # type: ignore + dot = sum(a * b for a, b in zip(v1, v2)) + norm1 = math.sqrt(sum(a * a for a in v1)) + norm2 = math.sqrt(sum(b * b for b in v2)) + if norm1 == 0.0 or norm2 == 0.0: + return 0.0 + return dot / (norm1 * norm2) + + fake_graph_store.db.collection("Queries").find = Mock(return_value=[]) + + # Simulate a previously cached query + stored_query = { + "text": "List all movies", + "aql": "FOR m IN Movies RETURN m", + "embedding": [0.123, 0.456, 0.789, 0.321, 0.654], + } + + def mock_execute(query, bind_vars): # type: ignore + # Handle vector search query + if "query_embedding" in bind_vars: + query_embedding = bind_vars["query_embedding"] + score = cosine_similarity(query_embedding, stored_query["embedding"]) + if score > bind_vars.get("score_threshold", 0.75): + return [{"aql": stored_query["aql"], "score": score}] + return [] + # Handle regular query execution + return [{"title": "Inception"}] + + fake_graph_store.db.aql.execute = Mock(side_effect=mock_execute) + + dummy_llm = RunnableLambda( + lambda prompt: "```FOR m IN Movies LIMIT 1 RETURN m```" + ) + + chain = ArangoGraphQAChain.from_llm( + llm=dummy_llm, # type: ignore + graph=fake_graph_store, + allow_dangerous_requests=True, + verbose=True, + return_aql_result=True, + ) + + chain.embedding = type( + "FakeEmbedding", + (), + { + "embed_query": staticmethod( + lambda text: { + "Find all movies": [0.123, 0.456, 0.789, 0.321, 0.654], + "Show me all movies": [0.120, 0.460, 0.780, 0.330, 0.650], + }.get(text, [0.0] * 5) + ) + }, + )() + + # 1. Test with exact search + result1 = chain.invoke({"query": "Find all movies", "use_query_cache": True}) + assert result1["aql_result"][0]["title"] == "Inception" + + # 2. Test with vector search + result2 = chain.invoke({"query": "Show me all movies", "use_query_cache": True}) + assert result2["aql_result"][0]["title"] == "Inception" + + # 3. Test with aql generation and store new query + result3 = chain.invoke( + {"query": "What is the name of the first movie?", "use_query_cache": True} + ) + assert result3["result"] == "```FOR m IN Movies LIMIT 1 RETURN m```" + + # 4. Test with query cache disabled + result4 = chain.invoke({"query": "What is the name of the first movie?"}) + assert result4["result"] == "```FOR m IN Movies LIMIT 1 RETURN m```"