From c47426a716b76174743308591311a7d248ffffdc Mon Sep 17 00:00:00 2001 From: MonikaLiu Date: Tue, 19 Aug 2025 15:20:29 -0700 Subject: [PATCH 01/10] drop ChatMessageHistory --- .../chains/graph_qa/arangodb.py | 43 ++++++++++++------- .../chains/graph_qa/prompts.py | 2 +- .../chains/test_graph_database.py | 20 +++------ .../tests/unit_tests/chains/test_graph_qa.py | 36 +++++++--------- 4 files changed, 50 insertions(+), 51 deletions(-) diff --git a/libs/arangodb/langchain_arangodb/chains/graph_qa/arangodb.py b/libs/arangodb/langchain_arangodb/chains/graph_qa/arangodb.py index e29df44..47ab33f 100644 --- a/libs/arangodb/langchain_arangodb/chains/graph_qa/arangodb.py +++ b/libs/arangodb/langchain_arangodb/chains/graph_qa/arangodb.py @@ -10,7 +10,7 @@ from langchain_core.callbacks import CallbackManagerForChainRun from langchain_core.embeddings import Embeddings from langchain_core.language_models import BaseLanguageModel -from langchain_core.messages import AIMessage, HumanMessage +from langchain_core.messages import AIMessage from langchain_core.prompts import BasePromptTemplate from langchain_core.runnables import Runnable from pydantic import Field @@ -20,7 +20,6 @@ AQL_GENERATION_PROMPT, AQL_QA_PROMPT, ) -from langchain_arangodb.chat_message_histories.arangodb import ArangoChatMessageHistory from langchain_arangodb.graphs.arangodb_graph import ArangoGraph AQL_WRITE_OPERATIONS: List[str] = [ @@ -59,7 +58,7 @@ class ArangoGraphQAChain(Chain): query_cache_similarity_threshold: float = Field(default=0.80) include_history: bool = Field(default=False) max_history_messages: int = Field(default=10) - chat_history_store: Optional[ArangoChatMessageHistory] = Field(default=None) + chat_history_collection_name: str = Field(default="ChatHistory") top_k: int = 10 """Number of results to return from the query""" @@ -419,19 +418,29 @@ def _call( # # Get Chat History # # ###################### - if include_history and self.chat_history_store is None: - raise ValueError( - "Chat message history is required if include_history is True" - ) + if not self.graph.db.has_collection(self.chat_history_collection_name): + self.graph.db.create_collection(self.chat_history_collection_name) if max_history_messages <= 0: raise ValueError("max_history_messages must be greater than 0") chat_history = [] - if include_history and self.chat_history_store is not None: - for msg in self.chat_history_store.messages[-max_history_messages:]: - cls = HumanMessage if msg.type == "human" else AIMessage - chat_history.append(cls(content=msg.content)) + if include_history: + aql = f""" + FOR doc IN {self.chat_history_collection_name} + SORT doc._key DESC + LIMIT @n + RETURN {{ + user_input: doc.user_input, + aql_query: doc.aql_query, + result: doc.result + }} + """ + cursor = self.graph.db.aql.execute( + aql, + bind_vars={"n": self.max_history_messages}, # type: ignore + ) + chat_history = [d for d in cursor][::-1] # type: ignore ###################### # Check Query Cache # @@ -629,10 +638,14 @@ def _call( # Store Chat History # ######################## - if self.chat_history_store: - self.chat_history_store.add_user_message(user_input) - self.chat_history_store.add_ai_message(aql_query) - self.chat_history_store.add_ai_message(result) + self.graph.db.insert_document( + self.chat_history_collection_name, + { + "user_input": user_input, + "aql_query": aql_query, + "result": result.content if isinstance(result, AIMessage) else result, + }, + ) return results diff --git a/libs/arangodb/langchain_arangodb/chains/graph_qa/prompts.py b/libs/arangodb/langchain_arangodb/chains/graph_qa/prompts.py index 21a724e..9d6ff3c 100644 --- a/libs/arangodb/langchain_arangodb/chains/graph_qa/prompts.py +++ b/libs/arangodb/langchain_arangodb/chains/graph_qa/prompts.py @@ -15,9 +15,9 @@ Rules for Using Chat History: - If the Chat History is not empty, use it only as a reference to help clarify the current User Input — for example, to resolve pronouns or implicit references. +- Each entry in Chat History includes the User Input, the AQL Query generated by the AI Model, and the AQL Result. Use all of them to generate the AQL Query. - Chat History is ordered chronologically. Prioritize latest entries when resolving context or references. - If the Chat History is empty, do not use it or refer to it in any way. Treat the User Input as a fully self-contained and standalone question. -- The Chat History includes the User Input, the AQL Query generated by the AI Model, and the interpertation of AQL Result. Use all of them to generate the AQL Query. Things you should do: - Think step by step. diff --git a/libs/arangodb/tests/integration_tests/chains/test_graph_database.py b/libs/arangodb/tests/integration_tests/chains/test_graph_database.py index 0edbe78..7202c3e 100644 --- a/libs/arangodb/tests/integration_tests/chains/test_graph_database.py +++ b/libs/arangodb/tests/integration_tests/chains/test_graph_database.py @@ -10,7 +10,6 @@ from langchain_core.runnables import RunnableLambda from langchain_arangodb.chains.graph_qa.arangodb import ArangoGraphQAChain -from langchain_arangodb.chat_message_histories.arangodb import ArangoChatMessageHistory from langchain_arangodb.graphs.arangodb_graph import ArangoGraph from tests.llms.fake_llm import FakeLLM @@ -1081,15 +1080,7 @@ def test_chat_history(db: StandardDatabase) -> None: ) graph.refresh_schema() - # 2. Create chat history store - history = ArangoChatMessageHistory( - session_id="test", - collection_name="test_chat_sessions", - db=db, - ) - history.clear() - - # 3. Dummy LLM: simulate coreference to "The Matrix" + # 2. Dummy LLM: simulate coreference to "The Matrix" def dummy_llm(prompt): # type: ignore if "when was it released" in str(prompt).lower(): # type: ignore return AIMessage( @@ -1116,15 +1107,16 @@ def dummy_llm(prompt): # type: ignore allow_dangerous_requests=True, include_history=True, max_history_messages=5, - chat_history_store=history, return_aql_result=True, return_aql_query=True, ) - # 4. Ask initial question - result1 = dummy_chain.invoke({"query": "What is the first movie?"}) + # 3. Ask initial question + result1 = dummy_chain.invoke( + {"query": "What is the first movie?", "include_history": False} + ) assert "Inception" in result1["aql_result"] - # 5. Ask follow-up question using pronoun "it" + # 4. Ask follow-up question using pronoun "it" result2 = dummy_chain.invoke({"query": "When was it released?"}) assert 1999 in result2["aql_result"] diff --git a/libs/arangodb/tests/unit_tests/chains/test_graph_qa.py b/libs/arangodb/tests/unit_tests/chains/test_graph_qa.py index 89944df..13745a6 100644 --- a/libs/arangodb/tests/unit_tests/chains/test_graph_qa.py +++ b/libs/arangodb/tests/unit_tests/chains/test_graph_qa.py @@ -11,7 +11,6 @@ from langchain_core.runnables import Runnable, RunnableLambda from langchain_arangodb.chains.graph_qa.arangodb import ArangoGraphQAChain -from langchain_arangodb.chat_message_histories.arangodb import ArangoChatMessageHistory from langchain_arangodb.graphs.arangodb_graph import ArangoGraph from tests.llms.fake_llm import FakeLLM @@ -651,14 +650,17 @@ def test_chat_history( ) -> None: """test _call with chat history""" - chat_history_store = Mock(spec=ArangoChatMessageHistory) - - # Add fake message history (as objects, not dicts) - chat_history_store.messages = [ - Mock(type="human", content="What is 1+1?"), - Mock(type="ai", content="2"), - Mock(type="human", content="What is 2+2?"), - Mock(type="ai", content="4"), + chat_history = [ + { + "user_input": "What is 1+1?", + "aql_query": "RETURN 1+1", + "result": "2", + }, + { + "user_input": "What is 2+2?", + "aql_query": "RETURN 2+2", + "result": "4", + }, ] # Mock LLM chains @@ -669,6 +671,8 @@ def test_chat_history( content="Here are the movies." ) # noqa: E501 + fake_graph_store.db.aql.execute.side_effect = [chat_history] + # Build the chain chain = ArangoGraphQAChain( graph=fake_graph_store, @@ -677,7 +681,6 @@ def test_chat_history( qa_chain=mock_chains["qa_chain"], allow_dangerous_requests=True, include_history=True, - chat_history_store=chat_history_store, max_history_messages=10, return_aql_result=True, ) @@ -685,25 +688,16 @@ def test_chat_history( # Run the call result = chain.invoke({"query": "List all movies"}) - # LLM received the latest 2 pairs (4 messages) + # LLM received the latest 2 docs llm_input = mock_chains["aql_generation_chain"].invoke.call_args[0][0] # type: ignore chat_history = llm_input["chat_history"] - assert len(chat_history) == 4 + assert len(chat_history) == 2 # result has expected fields assert result["result"].content == "Here are the movies." assert result["aql_result"][0]["title"] == "Inception" - # Error: chat history enabled but store is missing - chain.chat_history_store = None - with pytest.raises( - ValueError, - match="Chat message history is required if include_history is True", - ): - chain.invoke({"query": "List again"}) - # Error: invalid max_history_messages - chain.chat_history_store = chat_history_store chain.max_history_messages = 0 with pytest.raises( ValueError, match="max_history_messages must be greater than 0" From 138a9b74a11fe33ecff8afc13a847d18b0f224b5 Mon Sep 17 00:00:00 2001 From: MonikaLiu Date: Tue, 19 Aug 2025 16:45:16 -0700 Subject: [PATCH 02/10] change back to ChatMessageHistory --- .../chains/graph_qa/arangodb.py | 35 ++++++++++----- .../chat_message_histories/arangodb.py | 4 ++ .../chains/test_graph_database.py | 20 ++++++--- .../tests/unit_tests/chains/test_graph_qa.py | 44 ++++++++++++++++--- 4 files changed, 79 insertions(+), 24 deletions(-) diff --git a/libs/arangodb/langchain_arangodb/chains/graph_qa/arangodb.py b/libs/arangodb/langchain_arangodb/chains/graph_qa/arangodb.py index 47ab33f..d62cf05 100644 --- a/libs/arangodb/langchain_arangodb/chains/graph_qa/arangodb.py +++ b/libs/arangodb/langchain_arangodb/chains/graph_qa/arangodb.py @@ -20,6 +20,7 @@ AQL_GENERATION_PROMPT, AQL_QA_PROMPT, ) +from langchain_arangodb.chat_message_histories.arangodb import ArangoChatMessageHistory from langchain_arangodb.graphs.arangodb_graph import ArangoGraph AQL_WRITE_OPERATIONS: List[str] = [ @@ -58,7 +59,7 @@ class ArangoGraphQAChain(Chain): query_cache_similarity_threshold: float = Field(default=0.80) include_history: bool = Field(default=False) max_history_messages: int = Field(default=10) - chat_history_collection_name: str = Field(default="ChatHistory") + chat_history_store: Optional[ArangoChatMessageHistory] = Field(default=None) top_k: int = 10 """Number of results to return from the query""" @@ -418,16 +419,24 @@ def _call( # # Get Chat History # # ###################### - if not self.graph.db.has_collection(self.chat_history_collection_name): - self.graph.db.create_collection(self.chat_history_collection_name) + if include_history and self.chat_history_store is None: + raise ValueError( + "Chat message history is required if include_history is True" + ) if max_history_messages <= 0: raise ValueError("max_history_messages must be greater than 0") + if ( + self.chat_history_store is None + or self.chat_history_store._collection_name is None + ): + raise ValueError("Chat history store is not initialized") + chat_history = [] if include_history: aql = f""" - FOR doc IN {self.chat_history_collection_name} + FOR doc IN {self.chat_history_store._collection_name} SORT doc._key DESC LIMIT @n RETURN {{ @@ -638,14 +647,16 @@ def _call( # Store Chat History # ######################## - self.graph.db.insert_document( - self.chat_history_collection_name, - { - "user_input": user_input, - "aql_query": aql_query, - "result": result.content if isinstance(result, AIMessage) else result, - }, - ) + if self.chat_history_store is not None: + self.chat_history_store.add_doc( + { + "user_input": user_input, + "aql_query": aql_query, + "result": result.content + if isinstance(result, AIMessage) + else result, + } + ) return results diff --git a/libs/arangodb/langchain_arangodb/chat_message_histories/arangodb.py b/libs/arangodb/langchain_arangodb/chat_message_histories/arangodb.py index 49a7057..d5333cb 100644 --- a/libs/arangodb/langchain_arangodb/chat_message_histories/arangodb.py +++ b/libs/arangodb/langchain_arangodb/chat_message_histories/arangodb.py @@ -173,6 +173,10 @@ def add_message(self, message: BaseMessage) -> None: }, ) + def add_doc(self, doc: dict[str, Any]) -> None: + """Add a list of documents to the chat history.""" + self._db.insert_document(self._collection_name, doc) + def clear(self) -> None: """Clear session memory from ArangoDB. diff --git a/libs/arangodb/tests/integration_tests/chains/test_graph_database.py b/libs/arangodb/tests/integration_tests/chains/test_graph_database.py index 7202c3e..0edbe78 100644 --- a/libs/arangodb/tests/integration_tests/chains/test_graph_database.py +++ b/libs/arangodb/tests/integration_tests/chains/test_graph_database.py @@ -10,6 +10,7 @@ from langchain_core.runnables import RunnableLambda from langchain_arangodb.chains.graph_qa.arangodb import ArangoGraphQAChain +from langchain_arangodb.chat_message_histories.arangodb import ArangoChatMessageHistory from langchain_arangodb.graphs.arangodb_graph import ArangoGraph from tests.llms.fake_llm import FakeLLM @@ -1080,7 +1081,15 @@ def test_chat_history(db: StandardDatabase) -> None: ) graph.refresh_schema() - # 2. Dummy LLM: simulate coreference to "The Matrix" + # 2. Create chat history store + history = ArangoChatMessageHistory( + session_id="test", + collection_name="test_chat_sessions", + db=db, + ) + history.clear() + + # 3. Dummy LLM: simulate coreference to "The Matrix" def dummy_llm(prompt): # type: ignore if "when was it released" in str(prompt).lower(): # type: ignore return AIMessage( @@ -1107,16 +1116,15 @@ def dummy_llm(prompt): # type: ignore allow_dangerous_requests=True, include_history=True, max_history_messages=5, + chat_history_store=history, return_aql_result=True, return_aql_query=True, ) - # 3. Ask initial question - result1 = dummy_chain.invoke( - {"query": "What is the first movie?", "include_history": False} - ) + # 4. Ask initial question + result1 = dummy_chain.invoke({"query": "What is the first movie?"}) assert "Inception" in result1["aql_result"] - # 4. Ask follow-up question using pronoun "it" + # 5. Ask follow-up question using pronoun "it" result2 = dummy_chain.invoke({"query": "When was it released?"}) assert 1999 in result2["aql_result"] diff --git a/libs/arangodb/tests/unit_tests/chains/test_graph_qa.py b/libs/arangodb/tests/unit_tests/chains/test_graph_qa.py index 13745a6..f736f97 100644 --- a/libs/arangodb/tests/unit_tests/chains/test_graph_qa.py +++ b/libs/arangodb/tests/unit_tests/chains/test_graph_qa.py @@ -11,6 +11,7 @@ from langchain_core.runnables import Runnable, RunnableLambda from langchain_arangodb.chains.graph_qa.arangodb import ArangoGraphQAChain +from langchain_arangodb.chat_message_histories.arangodb import ArangoChatMessageHistory from langchain_arangodb.graphs.arangodb_graph import ArangoGraph from tests.llms.fake_llm import FakeLLM @@ -650,18 +651,41 @@ def test_chat_history( ) -> None: """test _call with chat history""" - chat_history = [ + chat_history_store = Mock(spec=ArangoChatMessageHistory) + + chat_history_store._collection_name = "ChatHistory" + + # Add fake message history + chat_history_store.add_doc( { "user_input": "What is 1+1?", "aql_query": "RETURN 1+1", "result": "2", - }, + } + ) + + chat_history_store.add_doc( { "user_input": "What is 2+2?", "aql_query": "RETURN 2+2", "result": "4", - }, - ] + } + ) + + fake_graph_store.db.aql.execute.return_value = iter( + [ + { + "user_input": "What is 1+1?", + "aql_query": "RETURN 1+1", + "result": "2", + }, + { + "user_input": "What is 2+2?", + "aql_query": "RETURN 2+2", + "result": "4", + }, + ] + ) # Mock LLM chains mock_chains[ # type: ignore @@ -671,8 +695,6 @@ def test_chat_history( content="Here are the movies." ) # noqa: E501 - fake_graph_store.db.aql.execute.side_effect = [chat_history] - # Build the chain chain = ArangoGraphQAChain( graph=fake_graph_store, @@ -681,6 +703,7 @@ def test_chat_history( qa_chain=mock_chains["qa_chain"], allow_dangerous_requests=True, include_history=True, + chat_history_store=chat_history_store, max_history_messages=10, return_aql_result=True, ) @@ -697,7 +720,16 @@ def test_chat_history( assert result["result"].content == "Here are the movies." assert result["aql_result"][0]["title"] == "Inception" + # Error: chat history enabled but store is missing + chain.chat_history_store = None + with pytest.raises( + ValueError, + match="Chat message history is required if include_history is True", + ): + chain.invoke({"query": "List again"}) + # Error: invalid max_history_messages + chain.chat_history_store = chat_history_store chain.max_history_messages = 0 with pytest.raises( ValueError, match="max_history_messages must be greater than 0" From 3b65f873b5d5cbe4625e4a454536b64d9a5ccba8 Mon Sep 17 00:00:00 2001 From: MonikaLiu Date: Tue, 19 Aug 2025 19:14:11 -0700 Subject: [PATCH 03/10] add feedback to chat history --- .../chains/graph_qa/arangodb.py | 27 ++++++++++--------- .../chains/graph_qa/prompts.py | 3 ++- .../chat_message_histories/arangodb.py | 2 +- 3 files changed, 17 insertions(+), 15 deletions(-) diff --git a/libs/arangodb/langchain_arangodb/chains/graph_qa/arangodb.py b/libs/arangodb/langchain_arangodb/chains/graph_qa/arangodb.py index d62cf05..03ea9b0 100644 --- a/libs/arangodb/langchain_arangodb/chains/graph_qa/arangodb.py +++ b/libs/arangodb/langchain_arangodb/chains/graph_qa/arangodb.py @@ -427,24 +427,25 @@ def _call( if max_history_messages <= 0: raise ValueError("max_history_messages must be greater than 0") - if ( - self.chat_history_store is None - or self.chat_history_store._collection_name is None - ): - raise ValueError("Chat history store is not initialized") - chat_history = [] + collection_name = self.chat_history_store._collection_name # type: ignore if include_history: aql = f""" - FOR doc IN {self.chat_history_store._collection_name} + FOR doc IN {collection_name} SORT doc._key DESC LIMIT @n - RETURN {{ - user_input: doc.user_input, - aql_query: doc.aql_query, - result: doc.result - }} - """ + RETURN + HAS(doc, "role") ? {{ + type: "feedback", + content: doc.content + }} : + {{ + type: "query-response pair", + user_input: doc.user_input, + aql_query: doc.aql_query, + result: doc.result + }} + """ cursor = self.graph.db.aql.execute( aql, bind_vars={"n": self.max_history_messages}, # type: ignore diff --git a/libs/arangodb/langchain_arangodb/chains/graph_qa/prompts.py b/libs/arangodb/langchain_arangodb/chains/graph_qa/prompts.py index 9d6ff3c..de45a5a 100644 --- a/libs/arangodb/langchain_arangodb/chains/graph_qa/prompts.py +++ b/libs/arangodb/langchain_arangodb/chains/graph_qa/prompts.py @@ -15,7 +15,8 @@ Rules for Using Chat History: - If the Chat History is not empty, use it only as a reference to help clarify the current User Input — for example, to resolve pronouns or implicit references. -- Each entry in Chat History includes the User Input, the AQL Query generated by the AI Model, and the AQL Result. Use all of them to generate the AQL Query. +- If the Chat History entry is a query-response pair containing User Input, AQL Query, and AQL Result, use all of them to generate the AQL Query. +- If the Chat History entry is a feedback message, use it to improve the AQL Query. Do not use it to generate the AQL Query. - Chat History is ordered chronologically. Prioritize latest entries when resolving context or references. - If the Chat History is empty, do not use it or refer to it in any way. Treat the User Input as a fully self-contained and standalone question. diff --git a/libs/arangodb/langchain_arangodb/chat_message_histories/arangodb.py b/libs/arangodb/langchain_arangodb/chat_message_histories/arangodb.py index d5333cb..1d5abfe 100644 --- a/libs/arangodb/langchain_arangodb/chat_message_histories/arangodb.py +++ b/libs/arangodb/langchain_arangodb/chat_message_histories/arangodb.py @@ -174,7 +174,7 @@ def add_message(self, message: BaseMessage) -> None: ) def add_doc(self, doc: dict[str, Any]) -> None: - """Add a list of documents to the chat history.""" + """Add a dict of message to the chat history.""" self._db.insert_document(self._collection_name, doc) def clear(self) -> None: From 40da2a7f31944c6cb254afba96f170ce0c717dd1 Mon Sep 17 00:00:00 2001 From: MonikaLiu Date: Tue, 19 Aug 2025 19:32:35 -0700 Subject: [PATCH 04/10] fix no attr err --- libs/arangodb/langchain_arangodb/chains/graph_qa/arangodb.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/libs/arangodb/langchain_arangodb/chains/graph_qa/arangodb.py b/libs/arangodb/langchain_arangodb/chains/graph_qa/arangodb.py index 03ea9b0..76d2d1f 100644 --- a/libs/arangodb/langchain_arangodb/chains/graph_qa/arangodb.py +++ b/libs/arangodb/langchain_arangodb/chains/graph_qa/arangodb.py @@ -427,8 +427,10 @@ def _call( if max_history_messages <= 0: raise ValueError("max_history_messages must be greater than 0") + if self.chat_history_store is not None: + collection_name = self.chat_history_store._collection_name # type: ignore + chat_history = [] - collection_name = self.chat_history_store._collection_name # type: ignore if include_history: aql = f""" FOR doc IN {collection_name} From 159774d0031a5aa9de94dd9c0cdbf25c82b716a1 Mon Sep 17 00:00:00 2001 From: MonikaLiu Date: Fri, 22 Aug 2025 17:09:01 -0700 Subject: [PATCH 05/10] add_doc -> add_qa_message --- .../chains/graph_qa/arangodb.py | 24 ++++--------------- .../chains/graph_qa/prompts.py | 4 ++-- .../chat_message_histories/arangodb.py | 13 +++++++--- .../tests/unit_tests/chains/test_graph_qa.py | 18 ++++---------- 4 files changed, 21 insertions(+), 38 deletions(-) diff --git a/libs/arangodb/langchain_arangodb/chains/graph_qa/arangodb.py b/libs/arangodb/langchain_arangodb/chains/graph_qa/arangodb.py index 76d2d1f..802946f 100644 --- a/libs/arangodb/langchain_arangodb/chains/graph_qa/arangodb.py +++ b/libs/arangodb/langchain_arangodb/chains/graph_qa/arangodb.py @@ -436,17 +436,7 @@ def _call( FOR doc IN {collection_name} SORT doc._key DESC LIMIT @n - RETURN - HAS(doc, "role") ? {{ - type: "feedback", - content: doc.content - }} : - {{ - type: "query-response pair", - user_input: doc.user_input, - aql_query: doc.aql_query, - result: doc.result - }} + RETURN UNSET(doc, ["_id", "_key", "_rev", "session_id"]) """ cursor = self.graph.db.aql.execute( aql, @@ -651,14 +641,10 @@ def _call( ######################## if self.chat_history_store is not None: - self.chat_history_store.add_doc( - { - "user_input": user_input, - "aql_query": aql_query, - "result": result.content - if isinstance(result, AIMessage) - else result, - } + self.chat_history_store.add_qa_message( + user_input, aql_query, result.content + if isinstance(result, AIMessage) + else result ) return results diff --git a/libs/arangodb/langchain_arangodb/chains/graph_qa/prompts.py b/libs/arangodb/langchain_arangodb/chains/graph_qa/prompts.py index de45a5a..e2c1fe6 100644 --- a/libs/arangodb/langchain_arangodb/chains/graph_qa/prompts.py +++ b/libs/arangodb/langchain_arangodb/chains/graph_qa/prompts.py @@ -15,8 +15,8 @@ Rules for Using Chat History: - If the Chat History is not empty, use it only as a reference to help clarify the current User Input — for example, to resolve pronouns or implicit references. -- If the Chat History entry is a query-response pair containing User Input, AQL Query, and AQL Result, use all of them to generate the AQL Query. -- If the Chat History entry is a feedback message, use it to improve the AQL Query. Do not use it to generate the AQL Query. +- If the Chat History entry has a role of "qa" which contains User Input, AQL Query, and AQL Result, use all of them to generate the AQL Query. +- If the Chat History entry has a role of "human", use it as feedback to improve the AQL Query. Do not use it to generate the AQL Query. - Chat History is ordered chronologically. Prioritize latest entries when resolving context or references. - If the Chat History is empty, do not use it or refer to it in any way. Treat the User Input as a fully self-contained and standalone question. diff --git a/libs/arangodb/langchain_arangodb/chat_message_histories/arangodb.py b/libs/arangodb/langchain_arangodb/chat_message_histories/arangodb.py index 1d5abfe..951f562 100644 --- a/libs/arangodb/langchain_arangodb/chat_message_histories/arangodb.py +++ b/libs/arangodb/langchain_arangodb/chat_message_histories/arangodb.py @@ -173,9 +173,16 @@ def add_message(self, message: BaseMessage) -> None: }, ) - def add_doc(self, doc: dict[str, Any]) -> None: - """Add a dict of message to the chat history.""" - self._db.insert_document(self._collection_name, doc) + def add_qa_message(self, user_input: str, aql_query: str, result: str) -> None: + """Add a QA message to the chat history.""" + self._db.collection(self._collection_name).insert( + { + "role": "qa", + "user_input": user_input, + "aql_query": aql_query, + "result": result, + }, + ) def clear(self) -> None: """Clear session memory from ArangoDB. diff --git a/libs/arangodb/tests/unit_tests/chains/test_graph_qa.py b/libs/arangodb/tests/unit_tests/chains/test_graph_qa.py index f736f97..531523b 100644 --- a/libs/arangodb/tests/unit_tests/chains/test_graph_qa.py +++ b/libs/arangodb/tests/unit_tests/chains/test_graph_qa.py @@ -656,20 +656,10 @@ def test_chat_history( chat_history_store._collection_name = "ChatHistory" # Add fake message history - chat_history_store.add_doc( - { - "user_input": "What is 1+1?", - "aql_query": "RETURN 1+1", - "result": "2", - } - ) - - chat_history_store.add_doc( - { - "user_input": "What is 2+2?", - "aql_query": "RETURN 2+2", - "result": "4", - } + chat_history_store.add_qa_message( + user_input="What is 1+1?", + aql_query="RETURN 1+1", + result="2", ) fake_graph_store.db.aql.execute.return_value = iter( From 3ed787527c041de29c35a25433633b4b8a239090 Mon Sep 17 00:00:00 2001 From: MonikaLiu Date: Fri, 22 Aug 2025 17:10:30 -0700 Subject: [PATCH 06/10] fix lint err --- .../arangodb/langchain_arangodb/chains/graph_qa/arangodb.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/libs/arangodb/langchain_arangodb/chains/graph_qa/arangodb.py b/libs/arangodb/langchain_arangodb/chains/graph_qa/arangodb.py index 802946f..c8f218a 100644 --- a/libs/arangodb/langchain_arangodb/chains/graph_qa/arangodb.py +++ b/libs/arangodb/langchain_arangodb/chains/graph_qa/arangodb.py @@ -642,9 +642,9 @@ def _call( if self.chat_history_store is not None: self.chat_history_store.add_qa_message( - user_input, aql_query, result.content - if isinstance(result, AIMessage) - else result + user_input, + aql_query, + result.content if isinstance(result, AIMessage) else result, # type: ignore ) return results From ace310c0e39a82ef61e8a545acb9f0ce5007631c Mon Sep 17 00:00:00 2001 From: MonikaLiu Date: Tue, 26 Aug 2025 12:56:50 -0700 Subject: [PATCH 07/10] add get_messages --- .../chains/graph_qa/arangodb.py | 17 ++--------- .../chat_message_histories/arangodb.py | 24 +++++++++++++-- .../tests/unit_tests/chains/test_graph_qa.py | 30 ++++++++++--------- 3 files changed, 41 insertions(+), 30 deletions(-) diff --git a/libs/arangodb/langchain_arangodb/chains/graph_qa/arangodb.py b/libs/arangodb/langchain_arangodb/chains/graph_qa/arangodb.py index 2a17a83..11e526b 100644 --- a/libs/arangodb/langchain_arangodb/chains/graph_qa/arangodb.py +++ b/libs/arangodb/langchain_arangodb/chains/graph_qa/arangodb.py @@ -427,22 +427,11 @@ def _call( if max_history_messages <= 0: raise ValueError("max_history_messages must be greater than 0") - if self.chat_history_store is not None: - collection_name = self.chat_history_store._collection_name # type: ignore - chat_history = [] - if include_history: - aql = f""" - FOR doc IN {collection_name} - SORT doc._key DESC - LIMIT @n - RETURN UNSET(doc, ["_id", "_key", "_rev", "session_id"]) - """ - cursor = self.graph.db.aql.execute( - aql, - bind_vars={"n": self.max_history_messages}, # type: ignore + if include_history and self.chat_history_store is not None: + chat_history.extend( + self.chat_history_store.get_messages(n_messages=max_history_messages) ) - chat_history = [d for d in cursor][::-1] # type: ignore ###################### # Check Query Cache # diff --git a/libs/arangodb/langchain_arangodb/chat_message_histories/arangodb.py b/libs/arangodb/langchain_arangodb/chat_message_histories/arangodb.py index 951f562..990b246 100644 --- a/libs/arangodb/langchain_arangodb/chat_message_histories/arangodb.py +++ b/libs/arangodb/langchain_arangodb/chat_message_histories/arangodb.py @@ -1,4 +1,4 @@ -from typing import Any, List, Union +from typing import Any, List, Optional, Union from arango.database import StandardDatabase from langchain_core.chat_history import BaseChatMessageHistory @@ -138,6 +138,25 @@ def messages(self, messages: List[BaseMessage]) -> None: " Use the 'add_messages' instead." ) + def get_messages(self, role: Optional[str] = None, n_messages: int = 10) -> list: + """Retrieve messages from ArangoDB, optionally filtered by role.""" + query = """ + FOR doc IN @@col + FILTER doc.session_id == @session_id + FILTER @role == null || doc.role == @role + SORT doc._key DESC + LIMIT @n + RETURN UNSET(doc, ["_id", "_key", "_rev", "session_id"]) + """ + bind_vars = { + "@col": self._collection_name, + "session_id": self._session_id, + "role": role, + "n": n_messages, + } + cursor = self._db.aql.execute(query, bind_vars=bind_vars) # type: ignore + return [d for d in cursor][::-1] # type: ignore + def add_message(self, message: BaseMessage) -> None: """Append the message to the record in ArangoDB. @@ -178,10 +197,11 @@ def add_qa_message(self, user_input: str, aql_query: str, result: str) -> None: self._db.collection(self._collection_name).insert( { "role": "qa", + "session_id": self._session_id, "user_input": user_input, "aql_query": aql_query, "result": result, - }, + } ) def clear(self) -> None: diff --git a/libs/arangodb/tests/unit_tests/chains/test_graph_qa.py b/libs/arangodb/tests/unit_tests/chains/test_graph_qa.py index 531523b..f4f3634 100644 --- a/libs/arangodb/tests/unit_tests/chains/test_graph_qa.py +++ b/libs/arangodb/tests/unit_tests/chains/test_graph_qa.py @@ -662,20 +662,22 @@ def test_chat_history( result="2", ) - fake_graph_store.db.aql.execute.return_value = iter( - [ - { - "user_input": "What is 1+1?", - "aql_query": "RETURN 1+1", - "result": "2", - }, - { - "user_input": "What is 2+2?", - "aql_query": "RETURN 2+2", - "result": "4", - }, - ] - ) + chat_history_store.get_messages.return_value = [ + { + "user_input": "What is 1+1?", + "aql_query": "RETURN 1+1", + "result": "2", + "role": "qa", + "session_id": "test", + }, + { + "user_input": "What is 2+2?", + "aql_query": "RETURN 2+2", + "result": "4", + "role": "qa", + "session_id": "test", + }, + ] # Mock LLM chains mock_chains[ # type: ignore From 7fa0cf00225702923e74906034336f058b4eb801 Mon Sep 17 00:00:00 2001 From: MonikaLiu Date: Thu, 28 Aug 2025 14:50:50 -0700 Subject: [PATCH 08/10] optimize get_messages & add time property --- .../chat_message_histories/arangodb.py | 24 ++++++++++++++----- 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/libs/arangodb/langchain_arangodb/chat_message_histories/arangodb.py b/libs/arangodb/langchain_arangodb/chat_message_histories/arangodb.py index 990b246..1628216 100644 --- a/libs/arangodb/langchain_arangodb/chat_message_histories/arangodb.py +++ b/libs/arangodb/langchain_arangodb/chat_message_histories/arangodb.py @@ -1,3 +1,4 @@ +import time from typing import Any, List, Optional, Union from arango.database import StandardDatabase @@ -138,23 +139,32 @@ def messages(self, messages: List[BaseMessage]) -> None: " Use the 'add_messages' instead." ) - def get_messages(self, role: Optional[str] = None, n_messages: int = 10) -> list: + def get_messages( + self, + role: Optional[str] = None, + n_messages: int = 10, + excluded_fields: list[str] = ["_id", "_key", "_rev", "session_id", "time"], + ) -> list: """Retrieve messages from ArangoDB, optionally filtered by role.""" - query = """ + query = f""" FOR doc IN @@col FILTER doc.session_id == @session_id - FILTER @role == null || doc.role == @role - SORT doc._key DESC + {"AND doc.role == @role" if role else ""} + SORT doc.time DESC LIMIT @n - RETURN UNSET(doc, ["_id", "_key", "_rev", "session_id"]) + RETURN UNSET(doc, @excluded_fields) """ bind_vars = { "@col": self._collection_name, "session_id": self._session_id, - "role": role, "n": n_messages, + "excluded_fields": excluded_fields, } + if role is not None: + bind_vars["role"] = role cursor = self._db.aql.execute(query, bind_vars=bind_vars) # type: ignore + + # return in chronological order return [d for d in cursor][::-1] # type: ignore def add_message(self, message: BaseMessage) -> None: @@ -189,6 +199,7 @@ def add_message(self, message: BaseMessage) -> None: "role": message.type, "content": message.content, "session_id": self._session_id, + "time": time.time(), }, ) @@ -198,6 +209,7 @@ def add_qa_message(self, user_input: str, aql_query: str, result: str) -> None: { "role": "qa", "session_id": self._session_id, + "time": time.time(), "user_input": user_input, "aql_query": aql_query, "result": result, From 45a5bb6a265b3c12e59be27d29d1b8aed840a67b Mon Sep 17 00:00:00 2001 From: MonikaLiu Date: Thu, 28 Aug 2025 14:51:32 -0700 Subject: [PATCH 09/10] test get_messages & add_qa_message --- .../chat_message_histories/test_arangodb.py | 50 ++++++++-- .../test_arangodb_chat_message_history.py | 97 ++++++++++++++++++- 2 files changed, 136 insertions(+), 11 deletions(-) diff --git a/libs/arangodb/tests/integration_tests/chat_message_histories/test_arangodb.py b/libs/arangodb/tests/integration_tests/chat_message_histories/test_arangodb.py index 14e8941..40dbec3 100644 --- a/libs/arangodb/tests/integration_tests/chat_message_histories/test_arangodb.py +++ b/libs/arangodb/tests/integration_tests/chat_message_histories/test_arangodb.py @@ -30,18 +30,18 @@ def test_add_messages(db: StandardDatabase) -> None: # Now check if the messages are stored in the database correctly assert len(message_store.messages) == 2 - assert isinstance(message_store.messages[0], HumanMessage) - assert isinstance(message_store.messages[1], AIMessage) - assert message_store.messages[0].content == "Hello! Language Chain!" - assert message_store.messages[1].content == "Hi Guys!" + assert isinstance(message_store.messages[0], AIMessage) + assert isinstance(message_store.messages[1], HumanMessage) + assert message_store.messages[0].content == "Hi Guys!" + assert message_store.messages[1].content == "Hello! Language Chain!" assert len(message_store_another.messages) == 3 assert isinstance(message_store_another.messages[0], HumanMessage) assert isinstance(message_store_another.messages[1], AIMessage) assert isinstance(message_store_another.messages[2], HumanMessage) - assert message_store_another.messages[0].content == "Hello! Bot!" + assert message_store_another.messages[0].content == "How's this pr going?" assert message_store_another.messages[1].content == "Hi there!" - assert message_store_another.messages[2].content == "How's this pr going?" + assert message_store_another.messages[2].content == "Hello! Bot!" # Now clear the first history message_store.clear() @@ -108,10 +108,10 @@ def test_arangodb_message_history_clear_messages( ] ) assert len(message_history.messages) == 2 - assert isinstance(message_history.messages[0], HumanMessage) - assert isinstance(message_history.messages[1], AIMessage) - assert message_history.messages[0].content == "You are a helpful assistant." - assert message_history.messages[1].content == "Hello" + assert isinstance(message_history.messages[0], AIMessage) + assert isinstance(message_history.messages[1], HumanMessage) + assert message_history.messages[0].content == "Hello" + assert message_history.messages[1].content == "You are a helpful assistant." message_history.clear() assert len(message_history.messages) == 0 @@ -155,3 +155,33 @@ def test_arangodb_message_history_clear_session_collection( # Delete the collection (equivalent to delete_session_node in Neo4j) db.delete_collection(collection_name) assert not db.has_collection(collection_name) + + +@pytest.mark.usefixtures("clear_arangodb_database") +def test_add_and_get_messages(db: StandardDatabase) -> None: + """Test adding a QA message to the collection.""" + message_history = ArangoChatMessageHistory(session_id="123", db=db) + message_history.add_qa_message( + user_input="What is 1+1?", + aql_query="RETURN 1+1", + result="2", + ) + message_history.add_messages( + [ + HumanMessage(content="You are a helpful assistant."), + AIMessage(content="Hello"), + ] + ) + all_messages = message_history.get_messages() + assert len(all_messages) == 3 + assert all_messages[0]["user_input"] == "What is 1+1?" + assert all_messages[0]["aql_query"] == "RETURN 1+1" + assert all_messages[0]["result"] == "2" + assert all_messages[1]["content"] == "You are a helpful assistant." + assert all_messages[2]["content"] == "Hello" + + qa_messages = message_history.get_messages(role="qa") + assert len(qa_messages) == 1 + assert qa_messages[0]["user_input"] == "What is 1+1?" + assert qa_messages[0]["aql_query"] == "RETURN 1+1" + assert qa_messages[0]["result"] == "2" diff --git a/libs/arangodb/tests/unit_tests/chat_message_histories/test_arangodb_chat_message_history.py b/libs/arangodb/tests/unit_tests/chat_message_histories/test_arangodb_chat_message_history.py index 28592b4..137066a 100644 --- a/libs/arangodb/tests/unit_tests/chat_message_histories/test_arangodb_chat_message_history.py +++ b/libs/arangodb/tests/unit_tests/chat_message_histories/test_arangodb_chat_message_history.py @@ -1,4 +1,4 @@ -from unittest.mock import MagicMock +from unittest.mock import ANY, MagicMock import pytest from arango.database import StandardDatabase @@ -129,6 +129,7 @@ def test_add_message() -> None: "role": "human", "content": "Hello, world!", "session_id": "test_session", + "time": ANY, } ) @@ -198,3 +199,97 @@ def test_messages_property() -> None: assert messages[0].content == "Hello" assert messages[1].type == "ai" assert messages[1].content == "Hi there" + + +def test_add_qa_message() -> None: + """Test adding a QA message to the collection.""" + mock_db = MagicMock(spec=StandardDatabase) + mock_collection = MagicMock() + mock_db.collection.return_value = mock_collection + mock_db.has_collection.return_value = True + mock_collection.indexes.return_value = [{"fields": ["session_id"]}] + + message_store = ArangoChatMessageHistory( + session_id="test_session", + db=mock_db, + ) + + # Add the message + message_store.add_qa_message( + user_input="What is 1+1?", + aql_query="RETURN 1+1", + result="2", + ) + + # Verify the message was added to the collection + mock_db.collection.assert_called_with("ChatHistory") + mock_collection.insert.assert_called_once_with( + { + "role": "qa", + "user_input": "What is 1+1?", + "aql_query": "RETURN 1+1", + "result": "2", + "session_id": "test_session", + "time": ANY, + } + ) + + +def test_get_messages() -> None: + """Test retrieving messages from the collection.""" + mock_db = MagicMock(spec=StandardDatabase) + mock_collection = MagicMock() + mock_aql = MagicMock() + mock_cursor = MagicMock() + mock_db.collection.return_value = mock_collection + mock_db.aql = mock_aql + mock_db.has_collection.return_value = True + mock_collection.indexes.return_value = [{"fields": ["session_id"]}] + mock_aql.execute.return_value = mock_cursor + + rows = [ + {"role": "human", "content": "Hello", "time": 1}, + {"role": "ai", "content": "Hi there", "time": 2}, + { + "role": "qa", + "user_input": "What is 1+1?", + "aql_query": "RETURN 1+1", + "result": "2", + "time": 3, + }, + ] + + def execute_side_effect(query=None, bind_vars=None, **kwargs): # type: ignore + role = (bind_vars or {}).get("role") + filtered = [r for r in rows if role is None or r["role"] == role] + # mimic: SORT doc.time DESC + return sorted(filtered, key=lambda r: r["time"], reverse=True) + + mock_aql.execute.side_effect = execute_side_effect + + message_store = ArangoChatMessageHistory( + session_id="test_session", + db=mock_db, + ) + + # Get the messages + human_messages = message_store.get_messages(role="human") + ai_messages = message_store.get_messages(role="ai") + qa_messages = message_store.get_messages(role="qa") + + assert len(human_messages) == 1 + assert len(ai_messages) == 1 + assert len(qa_messages) == 1 + assert human_messages[0]["content"] == "Hello" + assert ai_messages[0]["content"] == "Hi there" + assert qa_messages[0]["user_input"] == "What is 1+1?" + assert qa_messages[0]["aql_query"] == "RETURN 1+1" + assert qa_messages[0]["result"] == "2" + + all_messages = message_store.get_messages() + assert len(all_messages) == 3 + assert all_messages[0]["content"] == "Hello" + assert all_messages[1]["content"] == "Hi there" + assert all_messages[2]["user_input"] == "What is 1+1?" + assert all_messages[2]["aql_query"] == "RETURN 1+1" + assert all_messages[2]["result"] == "2" From 9aa543e97ffb78aa88f604e30b3abccbcd13c146 Mon Sep 17 00:00:00 2001 From: MonikaLiu Date: Fri, 29 Aug 2025 12:58:35 -0700 Subject: [PATCH 10/10] add docstring to funcs --- .../chat_message_histories/arangodb.py | 52 ++++++++++++++++++- 1 file changed, 50 insertions(+), 2 deletions(-) diff --git a/libs/arangodb/langchain_arangodb/chat_message_histories/arangodb.py b/libs/arangodb/langchain_arangodb/chat_message_histories/arangodb.py index 1628216..9ac67ec 100644 --- a/libs/arangodb/langchain_arangodb/chat_message_histories/arangodb.py +++ b/libs/arangodb/langchain_arangodb/chat_message_histories/arangodb.py @@ -48,10 +48,22 @@ class ArangoChatMessageHistory(BaseChatMessageHistory): history.add_user_message("Hello! How are you?") history.add_ai_message("I'm doing well, thank you!") + # Add QA message + history.add_qa_message( + user_input="Who is the first character?", + aql_query="FOR doc IN Characters LIMIT 1 RETURN doc", + result="The first character is Arya Stark." + ) + # Retrieve messages messages = history.messages print(f"Found {len(messages)} messages") + # Retrieve messages by role + human_messages = history.get_messages(role="human") + ai_messages = history.get_messages(role="ai") + qa_messages = history.get_messages(role="qa") + # Clear session history.clear() """ @@ -145,7 +157,27 @@ def get_messages( n_messages: int = 10, excluded_fields: list[str] = ["_id", "_key", "_rev", "session_id", "time"], ) -> list: - """Retrieve messages from ArangoDB, optionally filtered by role.""" + """Retrieve messages from ArangoDB, optionally filtered by role. + + :param role: Optional filter to retrieve messages of a specific role. + :type role: Optional[str] + :param n_messages: Number of messages to retrieve. + :type n_messages: int + :param excluded_fields: Fields to exclude from the returned messages. + :type excluded_fields: list[str] + + .. code-block:: python + + # Get all types of messages, default is 10 messages + messages = history.get_messages() + + # Get the first 20 human messages + messages = history.get_messages(role="human", n_messages=20) + + # Get the first 20 AI messages + messages = history.get_messages(role="ai", n_messages=20) + + """ query = f""" FOR doc IN @@col FILTER doc.session_id == @session_id @@ -204,7 +236,23 @@ def add_message(self, message: BaseMessage) -> None: ) def add_qa_message(self, user_input: str, aql_query: str, result: str) -> None: - """Add a QA message to the chat history.""" + """Add a QA message to the chat history. + + :param user_input: The user's input. + :type user_input: str + :param aql_query: The AQL query to execute. + :type aql_query: str + :param result: The result of the AQL query. + :type result: str + + .. code-block:: python + + history.add_qa_message( + user_input="Who is the first character?", + aql_query="FOR doc IN Characters LIMIT 1 RETURN doc", + result="The first character is Arya Stark." + ) + """ self._db.collection(self._collection_name).insert( { "role": "qa",