diff --git a/libs/arangodb/langchain_arangodb/chains/graph_qa/arangodb.py b/libs/arangodb/langchain_arangodb/chains/graph_qa/arangodb.py index e1f0d73..e29df44 100644 --- a/libs/arangodb/langchain_arangodb/chains/graph_qa/arangodb.py +++ b/libs/arangodb/langchain_arangodb/chains/graph_qa/arangodb.py @@ -10,7 +10,7 @@ from langchain_core.callbacks import CallbackManagerForChainRun from langchain_core.embeddings import Embeddings from langchain_core.language_models import BaseLanguageModel -from langchain_core.messages import AIMessage +from langchain_core.messages import AIMessage, HumanMessage from langchain_core.prompts import BasePromptTemplate from langchain_core.runnables import Runnable from pydantic import Field @@ -20,6 +20,7 @@ AQL_GENERATION_PROMPT, AQL_QA_PROMPT, ) +from langchain_arangodb.chat_message_histories.arangodb import ArangoChatMessageHistory from langchain_arangodb.graphs.arangodb_graph import ArangoGraph AQL_WRITE_OPERATIONS: List[str] = [ @@ -54,6 +55,11 @@ class ArangoGraphQAChain(Chain): qa_chain: Runnable[Dict[str, Any], Any] input_key: str = "query" #: :meta private: output_key: str = "result" #: :meta private: + use_query_cache: bool = Field(default=False) + query_cache_similarity_threshold: float = Field(default=0.80) + include_history: bool = Field(default=False) + max_history_messages: int = Field(default=10) + chat_history_store: Optional[ArangoChatMessageHistory] = Field(default=None) top_k: int = 10 """Number of results to return from the query""" @@ -142,6 +148,13 @@ def from_llm( :param query_cache_collection_name: The name of the collection to use for the query cache. :type query_cache_collection_name: str + :param include_history: Whether to include the chat history in the prompt. + :type include_history: bool + :param max_history_messages: The maximum number of messages to + include in the chat history. + :type max_history_messages: int + :param chat_history_store: The chat history store to use. 
+ :type chat_history_store: ArangoChatMessageHistory :param qa_prompt: The prompt to use for the QA chain. :type qa_prompt: BasePromptTemplate :param aql_generation_prompt: The prompt to use for the AQL generation chain. @@ -386,14 +399,40 @@ def _call( _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() callbacks = _run_manager.get_child() user_input = inputs[self.input_key].strip().lower() - use_query_cache = inputs.get("use_query_cache", False) + + # Query Cache Parameters (can be overridden by inputs at runtime) + use_query_cache = inputs.get("use_query_cache", self.use_query_cache) query_cache_similarity_threshold = inputs.get( - "query_cache_similarity_threshold", 0.80 + "query_cache_similarity_threshold", self.query_cache_similarity_threshold + ) + + # Chat History Parameters (can be overridden by inputs at runtime) + include_history = inputs.get("include_history", self.include_history) + max_history_messages = inputs.get( + "max_history_messages", self.max_history_messages ) if use_query_cache and self.embedding is None: raise ValueError("Cannot enable query cache without passing embedding") + # ###################### + # # Get Chat History # + # ###################### + + if include_history and self.chat_history_store is None: + raise ValueError( + "Chat message history is required if include_history is True" + ) + + if max_history_messages <= 0: + raise ValueError("max_history_messages must be greater than 0") + + chat_history = [] + if include_history and self.chat_history_store is not None: + for msg in self.chat_history_store.messages[-max_history_messages:]: + cls = HumanMessage if msg.type == "human" else AIMessage + chat_history.append(cls(content=msg.content)) + ###################### # Check Query Cache # ###################### @@ -422,6 +461,7 @@ def _call( "adb_schema": self.graph.schema_yaml, "aql_examples": self.aql_examples, "user_input": user_input, + "chat_history": chat_history, }, callbacks=callbacks, ) @@ -564,6 
+604,16 @@ def _call( callbacks=callbacks, ) + # Add summary + text = "Summary:" + _run_manager.on_text(text, end="\n", verbose=self.verbose) + _run_manager.on_text( + str(result.content) if isinstance(result, AIMessage) else result, + color="green", + end="\n", + verbose=self.verbose, + ) + results: Dict[str, Any] = {self.output_key: result} if self.return_aql_query: @@ -575,6 +625,15 @@ def _call( self._last_user_input = user_input self._last_aql_query = aql_query + ######################## + # Store Chat History # + ######################## + + if self.chat_history_store: + self.chat_history_store.add_user_message(user_input) + self.chat_history_store.add_ai_message(aql_query) + self.chat_history_store.add_ai_message(result) + return results def _is_read_only_query(self, aql_query: str) -> Tuple[bool, Optional[str]]: diff --git a/libs/arangodb/langchain_arangodb/chains/graph_qa/prompts.py b/libs/arangodb/langchain_arangodb/chains/graph_qa/prompts.py index 658bbc4..21a724e 100644 --- a/libs/arangodb/langchain_arangodb/chains/graph_qa/prompts.py +++ b/libs/arangodb/langchain_arangodb/chains/graph_qa/prompts.py @@ -3,7 +3,7 @@ AQL_GENERATION_TEMPLATE = """Task: Generate an ArangoDB Query Language (AQL) query from a User Input. -You are an ArangoDB Query Language (AQL) expert responsible for translating a `User Input` into an ArangoDB Query Language (AQL) query. +You are an ArangoDB Query Language (AQL) expert responsible for translating a `User Input` into an ArangoDB Query Language (AQL) query. You may also be given a `Chat History` to help you create the `AQL Query`. You are given an `ArangoDB Schema`. It is a YAML Spec containing: 1. `Graph Schema`: Lists all Graphs within the ArangoDB Database Instance, along with their Edge Relationships. @@ -13,8 +13,18 @@ You may also be given a set of `AQL Query Examples` to help you create the `AQL Query`. If provided, the `AQL Query Examples` should be used as a reference, similar to how `ArangoDB Schema` should be used. 
+Rules for Using Chat History: +- If the Chat History is not empty, use it only as a reference to help clarify the current User Input — for example, to resolve pronouns or implicit references. +- Chat History is ordered chronologically. Prioritize latest entries when resolving context or references. +- If the Chat History is empty, do not use it or refer to it in any way. Treat the User Input as a fully self-contained and standalone question. +- The Chat History includes the User Input, the AQL Query generated by the AI Model, and the interpretation of AQL Result. Use all of them to generate the AQL Query. + Things you should do: - Think step by step. +- When both INBOUND and OUTBOUND traversals are possible for a given edge, be extra careful to select the direction that accurately reflects the intended relationship based on the user input and the edge semantics. + Use OUTBOUND to traverse from _from to _to. Use INBOUND to traverse from _to to _from. Refer to the edge's definition in the schema (e.g., collection names or descriptions) to decide which direction reflects the intended relationship. +- Pay close attention to descriptive references in the User Input — including gendered terms (e.g., father, she), attribute-based descriptions (e.g., young, active, French), and implicit types or categories + (e.g., products over $100, available items) — and, if these correspond to fields in the schema, include appropriate filters in the AQL query (e.g., gender == "male", status == "active", price > 100). - Rely on `ArangoDB Schema` and `AQL Query Examples` (if provided) to generate the query. - Begin the `AQL Query` by the `WITH` AQL keyword to specify all of the ArangoDB Collections required. - If a `View Schema` is defined and contains analyzers for specific fields, prefer using the View with the `SEARCH` and `ANALYZER` clauses instead of a direct collection scan. @@ -25,13 +35,19 @@ - If a request is unrelated to generating AQL Query, say that you cannot help the user. 
Things you should not do: +- Do not use or refer to Chat History if it is empty. +- Do not assume any previously discussed context, or try to resolve pronouns or references to prior questions if the Chat History is empty. - Do not use any properties/relationships that can't be inferred from the `ArangoDB Schema` or the `AQL Query Examples`. - Do not include any text except the generated AQL Query. - Do not provide explanations or apologies in your responses. - Do not generate an AQL Query that removes or deletes any data. +- Do not answer or respond to messages in the Chat History. Under no circumstance should you generate an AQL Query that deletes any data whatsoever. +Chat History (Optional): +{chat_history} + ArangoDB Schema: {adb_schema} @@ -45,7 +61,7 @@ """ AQL_GENERATION_PROMPT = PromptTemplate( - input_variables=["adb_schema", "aql_examples", "user_input"], + input_variables=["adb_schema", "aql_examples", "user_input", "chat_history"], template=AQL_GENERATION_TEMPLATE, ) @@ -121,6 +137,11 @@ Summary: """ AQL_QA_PROMPT = PromptTemplate( - input_variables=["adb_schema", "user_input", "aql_query", "aql_result"], + input_variables=[ + "adb_schema", + "user_input", + "aql_query", + "aql_result", + ], template=AQL_QA_TEMPLATE, ) diff --git a/libs/arangodb/tests/integration_tests/chains/test_graph_database.py b/libs/arangodb/tests/integration_tests/chains/test_graph_database.py index 6eb0f58..0edbe78 100644 --- a/libs/arangodb/tests/integration_tests/chains/test_graph_database.py +++ b/libs/arangodb/tests/integration_tests/chains/test_graph_database.py @@ -10,6 +10,7 @@ from langchain_core.runnables import RunnableLambda from langchain_arangodb.chains.graph_qa.arangodb import ArangoGraphQAChain +from langchain_arangodb.chat_message_histories.arangodb import ArangoChatMessageHistory from langchain_arangodb.graphs.arangodb_graph import ArangoGraph from tests.llms.fake_llm import FakeLLM @@ -1062,3 +1063,68 @@ def test_query_cache(db: StandardDatabase) -> None: )() 
query = chain._get_cached_query("gibberish", 0.99) assert query is None + + +@pytest.mark.usefixtures("clear_arangodb_database") +def test_chat_history(db: StandardDatabase) -> None: + """ + Test chat history that enables context-aware query generation. + """ + # 1. Create required collections + graph = ArangoGraph(db) + db.create_collection("Movies") + db.collection("Movies").insert_many( + [ + {"_key": "matrix", "title": "The Matrix", "year": 1999}, + {"_key": "inception", "title": "Inception", "year": 2010}, + ] + ) + graph.refresh_schema() + + # 2. Create chat history store + history = ArangoChatMessageHistory( + session_id="test", + collection_name="test_chat_sessions", + db=db, + ) + history.clear() + + # 3. Dummy LLM: simulate coreference to "The Matrix" + def dummy_llm(prompt): # type: ignore + if "when was it released" in str(prompt).lower(): # type: ignore + return AIMessage( + content="""```aql + WITH Movies + FOR m IN Movies + FILTER m.title == "The Matrix" + RETURN m.year + ```""" + ) + return AIMessage( + content="""```aql + WITH Movies + FOR m IN Movies + SORT m._key ASC + LIMIT 1 + RETURN m.title + ```""" + ) + + dummy_chain = ArangoGraphQAChain.from_llm( + llm=RunnableLambda(dummy_llm), # type: ignore + graph=graph, + allow_dangerous_requests=True, + include_history=True, + max_history_messages=5, + chat_history_store=history, + return_aql_result=True, + return_aql_query=True, + ) + + # 4. Ask initial question + result1 = dummy_chain.invoke({"query": "What is the first movie?"}) + assert "Inception" in result1["aql_result"] + + # 5. 
Ask follow-up question using pronoun "it" + result2 = dummy_chain.invoke({"query": "When was it released?"}) + assert 1999 in result2["aql_result"] diff --git a/libs/arangodb/tests/unit_tests/chains/test_graph_qa.py b/libs/arangodb/tests/unit_tests/chains/test_graph_qa.py index b6312af..89944df 100644 --- a/libs/arangodb/tests/unit_tests/chains/test_graph_qa.py +++ b/libs/arangodb/tests/unit_tests/chains/test_graph_qa.py @@ -11,6 +11,7 @@ from langchain_core.runnables import Runnable, RunnableLambda from langchain_arangodb.chains.graph_qa.arangodb import ArangoGraphQAChain +from langchain_arangodb.chat_message_histories.arangodb import ArangoChatMessageHistory from langchain_arangodb.graphs.arangodb_graph import ArangoGraph from tests.llms.fake_llm import FakeLLM @@ -644,3 +645,67 @@ def mock_execute(query, bind_vars): # type: ignore # 4. Test with query cache disabled result4 = chain.invoke({"query": "What is the name of the first movie?"}) assert result4["result"] == "```FOR m IN Movies LIMIT 1 RETURN m```" + + def test_chat_history( + self, fake_graph_store: FakeGraphStore, mock_chains: Dict[str, Runnable] + ) -> None: + """test _call with chat history""" + + chat_history_store = Mock(spec=ArangoChatMessageHistory) + + # Add fake message history (as objects, not dicts) + chat_history_store.messages = [ + Mock(type="human", content="What is 1+1?"), + Mock(type="ai", content="2"), + Mock(type="human", content="What is 2+2?"), + Mock(type="ai", content="4"), + ] + + # Mock LLM chains + mock_chains[ # type: ignore + "aql_generation_chain" + ].invoke.return_value = "```aql\nFOR m IN Movies RETURN m\n```" # noqa: E501 + mock_chains["qa_chain"].invoke.return_value = AIMessage( # type: ignore + content="Here are the movies." 
+ ) # noqa: E501 + + # Build the chain + chain = ArangoGraphQAChain( + graph=fake_graph_store, + aql_generation_chain=mock_chains["aql_generation_chain"], + aql_fix_chain=mock_chains["aql_fix_chain"], + qa_chain=mock_chains["qa_chain"], + allow_dangerous_requests=True, + include_history=True, + chat_history_store=chat_history_store, + max_history_messages=10, + return_aql_result=True, + ) + + # Run the call + result = chain.invoke({"query": "List all movies"}) + + # LLM received the latest 2 pairs (4 messages) + llm_input = mock_chains["aql_generation_chain"].invoke.call_args[0][0] # type: ignore + chat_history = llm_input["chat_history"] + assert len(chat_history) == 4 + + # result has expected fields + assert result["result"].content == "Here are the movies." + assert result["aql_result"][0]["title"] == "Inception" + + # Error: chat history enabled but store is missing + chain.chat_history_store = None + with pytest.raises( + ValueError, + match="Chat message history is required if include_history is True", + ): + chain.invoke({"query": "List again"}) + + # Error: invalid max_history_messages + chain.chat_history_store = chat_history_store + chain.max_history_messages = 0 + with pytest.raises( + ValueError, match="max_history_messages must be greater than 0" + ): + chain.invoke({"query": "List again"})