diff --git a/libs/arangodb/langchain_arangodb/chains/graph_qa/arangodb.py b/libs/arangodb/langchain_arangodb/chains/graph_qa/arangodb.py index b44aaed..11e526b 100644 --- a/libs/arangodb/langchain_arangodb/chains/graph_qa/arangodb.py +++ b/libs/arangodb/langchain_arangodb/chains/graph_qa/arangodb.py @@ -10,7 +10,7 @@ from langchain_core.callbacks import CallbackManagerForChainRun from langchain_core.embeddings import Embeddings from langchain_core.language_models import BaseLanguageModel -from langchain_core.messages import AIMessage, HumanMessage +from langchain_core.messages import AIMessage from langchain_core.prompts import BasePromptTemplate from langchain_core.runnables import Runnable from pydantic import Field @@ -429,9 +429,9 @@ def _call( chat_history = [] if include_history and self.chat_history_store is not None: - for msg in self.chat_history_store.messages[-max_history_messages:]: - cls = HumanMessage if msg.type == "human" else AIMessage - chat_history.append(cls(content=msg.content)) + chat_history.extend( + self.chat_history_store.get_messages(n_messages=max_history_messages) + ) ###################### # Check Query Cache # @@ -631,10 +631,12 @@ def _call( # Store Chat History # ######################## - if self.chat_history_store: - self.chat_history_store.add_user_message(user_input) - self.chat_history_store.add_ai_message(aql_query) - self.chat_history_store.add_ai_message(content) + if self.chat_history_store is not None: + self.chat_history_store.add_qa_message( + user_input, + aql_query, + result.content if isinstance(result, AIMessage) else result, # type: ignore + ) return results diff --git a/libs/arangodb/langchain_arangodb/chains/graph_qa/prompts.py b/libs/arangodb/langchain_arangodb/chains/graph_qa/prompts.py index 40c64df..45d4e01 100644 --- a/libs/arangodb/langchain_arangodb/chains/graph_qa/prompts.py +++ b/libs/arangodb/langchain_arangodb/chains/graph_qa/prompts.py @@ -15,9 +15,10 @@ Rules for Using Chat History: - If the Chat History is not empty, use it only as a reference to help clarify the current User Input — for example, to resolve pronouns or implicit references. +- If the Chat History entry has a role of "qa" which contains User Input, AQL Query, and AQL Result, use all of them to generate the AQL Query. +- If the Chat History entry has a role of "human", use it as feedback to improve the AQL Query. Do not use it to generate the AQL Query. - Chat History is ordered chronologically. Prioritize latest entries when resolving context or references. - If the Chat History is empty, do not use it or refer to it in any way. Treat the User Input as a fully self-contained and standalone question. -- The Chat History includes the User Input, the AQL Query generated by the AI Model, and the interpertation of AQL Result. Use all of them to generate the AQL Query. Things you should do: - Think step by step. diff --git a/libs/arangodb/langchain_arangodb/chat_message_histories/arangodb.py b/libs/arangodb/langchain_arangodb/chat_message_histories/arangodb.py index 49a7057..9ac67ec 100644 --- a/libs/arangodb/langchain_arangodb/chat_message_histories/arangodb.py +++ b/libs/arangodb/langchain_arangodb/chat_message_histories/arangodb.py @@ -1,4 +1,5 @@ -from typing import Any, List, Union +import time +from typing import Any, List, Optional, Union from arango.database import StandardDatabase from langchain_core.chat_history import BaseChatMessageHistory @@ -47,10 +48,22 @@ class ArangoChatMessageHistory(BaseChatMessageHistory): history.add_user_message("Hello! How are you?") history.add_ai_message("I'm doing well, thank you!") + # Add QA message + history.add_qa_message( + user_input="Who is the first character?", + aql_query="FOR doc IN Characters LIMIT 1 RETURN doc", + result="The first character is Arya Stark." + ) + # Retrieve messages messages = history.messages print(f"Found {len(messages)} messages") + # Retrieve messages by role + human_messages = history.get_messages(role="human") + ai_messages = history.get_messages(role="ai") + qa_messages = history.get_messages(role="qa") + # Clear session history.clear() """ @@ -138,6 +151,54 @@ def messages(self, messages: List[BaseMessage]) -> None: " Use the 'add_messages' instead." ) + def get_messages( + self, + role: Optional[str] = None, + n_messages: int = 10, + excluded_fields: list[str] = ["_id", "_key", "_rev", "session_id", "time"], + ) -> list: + """Retrieve messages from ArangoDB, optionally filtered by role. + + :param role: Optional filter to retrieve messages of a specific role. + :type role: Optional[str] + :param n_messages: Number of messages to retrieve. + :type n_messages: int + :param excluded_fields: Fields to exclude from the returned messages. + :type excluded_fields: list[str] + + .. code-block:: python + + # Get all types of messages, default is 10 messages + messages = history.get_messages() + + # Get the first 20 human messages + messages = history.get_messages(role="human", n_messages=20) + + # Get the first 20 AI messages + messages = history.get_messages(role="ai", n_messages=20) + + """ + query = f""" + FOR doc IN @@col + FILTER doc.session_id == @session_id + {"AND doc.role == @role" if role else ""} + SORT doc.time DESC + LIMIT @n + RETURN UNSET(doc, @excluded_fields) + """ + bind_vars = { + "@col": self._collection_name, + "session_id": self._session_id, + "n": n_messages, + "excluded_fields": excluded_fields, + } + if role is not None: + bind_vars["role"] = role + cursor = self._db.aql.execute(query, bind_vars=bind_vars) # type: ignore + + # return in chronological order + return [d for d in cursor][::-1] # type: ignore + def add_message(self, message: BaseMessage) -> None: """Append the message to the record in ArangoDB. @@ -170,9 +231,39 @@ def add_message(self, message: BaseMessage) -> None: "role": message.type, "content": message.content, "session_id": self._session_id, + "time": time.time(), }, ) + def add_qa_message(self, user_input: str, aql_query: str, result: str) -> None: + """Add a QA message to the chat history. + + :param user_input: The user's input. + :type user_input: str + :param aql_query: The AQL query to execute. + :type aql_query: str + :param result: The result of the AQL query. + :type result: str + + .. code-block:: python + + history.add_qa_message( + user_input="Who is the first character?", + aql_query="FOR doc IN Characters LIMIT 1 RETURN doc", + result="The first character is Arya Stark." + ) + """ + self._db.collection(self._collection_name).insert( + { + "role": "qa", + "session_id": self._session_id, + "time": time.time(), + "user_input": user_input, + "aql_query": aql_query, + "result": result, + } + ) + def clear(self) -> None: """Clear session memory from ArangoDB. diff --git a/libs/arangodb/tests/integration_tests/chat_message_histories/test_arangodb.py b/libs/arangodb/tests/integration_tests/chat_message_histories/test_arangodb.py index 14e8941..40dbec3 100644 --- a/libs/arangodb/tests/integration_tests/chat_message_histories/test_arangodb.py +++ b/libs/arangodb/tests/integration_tests/chat_message_histories/test_arangodb.py @@ -30,18 +30,18 @@ def test_add_messages(db: StandardDatabase) -> None: # Now check if the messages are stored in the database correctly assert len(message_store.messages) == 2 - assert isinstance(message_store.messages[0], HumanMessage) - assert isinstance(message_store.messages[1], AIMessage) - assert message_store.messages[0].content == "Hello! Language Chain!" - assert message_store.messages[1].content == "Hi Guys!" + assert isinstance(message_store.messages[0], AIMessage) + assert isinstance(message_store.messages[1], HumanMessage) + assert message_store.messages[0].content == "Hi Guys!" + assert message_store.messages[1].content == "Hello! Language Chain!" assert len(message_store_another.messages) == 3 assert isinstance(message_store_another.messages[0], HumanMessage) assert isinstance(message_store_another.messages[1], AIMessage) assert isinstance(message_store_another.messages[2], HumanMessage) - assert message_store_another.messages[0].content == "Hello! Bot!" + assert message_store_another.messages[0].content == "How's this pr going?" assert message_store_another.messages[1].content == "Hi there!" - assert message_store_another.messages[2].content == "How's this pr going?" + assert message_store_another.messages[2].content == "Hello! Bot!" # Now clear the first history message_store.clear() @@ -108,10 +108,10 @@ def test_arangodb_message_history_clear_messages( ] ) assert len(message_history.messages) == 2 - assert isinstance(message_history.messages[0], HumanMessage) - assert isinstance(message_history.messages[1], AIMessage) - assert message_history.messages[0].content == "You are a helpful assistant." - assert message_history.messages[1].content == "Hello" + assert isinstance(message_history.messages[0], AIMessage) + assert isinstance(message_history.messages[1], HumanMessage) + assert message_history.messages[0].content == "Hello" + assert message_history.messages[1].content == "You are a helpful assistant." message_history.clear() assert len(message_history.messages) == 0 @@ -155,3 +155,33 @@ def test_arangodb_message_history_clear_session_collection( # Delete the collection (equivalent to delete_session_node in Neo4j) db.delete_collection(collection_name) assert not db.has_collection(collection_name) + + +@pytest.mark.usefixtures("clear_arangodb_database") +def test_add_and_get_messages(db: StandardDatabase) -> None: + """Test adding a QA message to the collection.""" + message_history = ArangoChatMessageHistory(session_id="123", db=db) + message_history.add_qa_message( + user_input="What is 1+1?", + aql_query="RETURN 1+1", + result="2", + ) + message_history.add_messages( + [ + HumanMessage(content="You are a helpful assistant."), + AIMessage(content="Hello"), + ] + ) + all_messages = message_history.get_messages() + assert len(all_messages) == 3 + assert all_messages[0]["user_input"] == "What is 1+1?" + assert all_messages[0]["aql_query"] == "RETURN 1+1" + assert all_messages[0]["result"] == "2" + assert all_messages[1]["content"] == "You are a helpful assistant." + assert all_messages[2]["content"] == "Hello" + + qa_messages = message_history.get_messages(role="qa") + assert len(qa_messages) == 1 + assert qa_messages[0]["user_input"] == "What is 1+1?" + assert qa_messages[0]["aql_query"] == "RETURN 1+1" + assert qa_messages[0]["result"] == "2" diff --git a/libs/arangodb/tests/unit_tests/chains/test_graph_qa.py b/libs/arangodb/tests/unit_tests/chains/test_graph_qa.py index 89944df..f4f3634 100644 --- a/libs/arangodb/tests/unit_tests/chains/test_graph_qa.py +++ b/libs/arangodb/tests/unit_tests/chains/test_graph_qa.py @@ -653,12 +653,30 @@ def test_chat_history( chat_history_store = Mock(spec=ArangoChatMessageHistory) - # Add fake message history (as objects, not dicts) - chat_history_store.messages = [ - Mock(type="human", content="What is 1+1?"), - Mock(type="ai", content="2"), - Mock(type="human", content="What is 2+2?"), - Mock(type="ai", content="4"), + chat_history_store._collection_name = "ChatHistory" + + # Add fake message history + chat_history_store.add_qa_message( + user_input="What is 1+1?", + aql_query="RETURN 1+1", + result="2", + ) + + chat_history_store.get_messages.return_value = [ + { + "user_input": "What is 1+1?", + "aql_query": "RETURN 1+1", + "result": "2", + "role": "qa", + "session_id": "test", + }, + { + "user_input": "What is 2+2?", + "aql_query": "RETURN 2+2", + "result": "4", + "role": "qa", + "session_id": "test", + }, ] # Mock LLM chains @@ -685,10 +703,10 @@ def test_chat_history( # Run the call result = chain.invoke({"query": "List all movies"}) - # LLM received the latest 2 pairs (4 messages) + # LLM received the latest 2 docs llm_input = mock_chains["aql_generation_chain"].invoke.call_args[0][0] # type: ignore chat_history = llm_input["chat_history"] - assert len(chat_history) == 4 + assert len(chat_history) == 2 # result has expected fields assert result["result"].content == "Here are the movies." diff --git a/libs/arangodb/tests/unit_tests/chat_message_histories/test_arangodb_chat_message_history.py b/libs/arangodb/tests/unit_tests/chat_message_histories/test_arangodb_chat_message_history.py index 28592b4..137066a 100644 --- a/libs/arangodb/tests/unit_tests/chat_message_histories/test_arangodb_chat_message_history.py +++ b/libs/arangodb/tests/unit_tests/chat_message_histories/test_arangodb_chat_message_history.py @@ -1,4 +1,4 @@ -from unittest.mock import MagicMock +from unittest.mock import ANY, MagicMock import pytest from arango.database import StandardDatabase @@ -129,6 +129,7 @@ def test_add_message() -> None: "role": "human", "content": "Hello, world!", "session_id": "test_session", + "time": ANY, } ) @@ -198,3 +199,97 @@ def test_messages_property() -> None: assert messages[0].content == "Hello" assert messages[1].type == "ai" assert messages[1].content == "Hi there" + + +def test_add_qa_message() -> None: + """Test adding a QA message to the collection.""" + mock_db = MagicMock(spec=StandardDatabase) + mock_collection = MagicMock() + mock_db.collection.return_value = mock_collection + mock_db.has_collection.return_value = True + mock_collection.indexes.return_value = [{"fields": ["session_id"]}] + + message_store = ArangoChatMessageHistory( + session_id="test_session", + db=mock_db, + ) + + # Add the message + message_store.add_qa_message( + user_input="What is 1+1?", + aql_query="RETURN 1+1", + result="2", + ) + + # Verify the message was added to the collection + mock_db.collection.assert_called_with("ChatHistory") + mock_collection.insert.assert_called_once_with( + { + "role": "qa", + "user_input": "What is 1+1?", + "aql_query": "RETURN 1+1", + "result": "2", + "session_id": "test_session", + "time": ANY, + } + ) + + +def test_get_messages() -> None: + """Test retrieving messages from the collection.""" + mock_db = MagicMock(spec=StandardDatabase) + mock_collection = MagicMock() + mock_aql = MagicMock() + mock_cursor = MagicMock() + mock_db.collection.return_value = mock_collection + mock_db.aql = mock_aql + mock_db.has_collection.return_value = True + mock_collection.indexes.return_value = [{"fields": ["session_id"]}] + mock_aql.execute.return_value = mock_cursor + + rows = [ + {"role": "human", "content": "Hello", "time": 1}, + {"role": "ai", "content": "Hi there", "time": 2}, + { + "role": "qa", + "user_input": "What is 1+1?", + "aql_query": "RETURN 1+1", + "result": "2", + "time": 3, + }, + ] + + def execute_side_effect(query=None, bind_vars=None, **kwargs): # type: ignore + role = (bind_vars or {}).get("role") + filtered = [r for r in rows if role is None or r["role"] == role] + # mimic: SORT doc.time DESC + return sorted(filtered, key=lambda r: r["time"], reverse=True) + + mock_aql.execute.side_effect = execute_side_effect + + message_store = ArangoChatMessageHistory( + session_id="test_session", + db=mock_db, + ) + + # Get the messages + human_messages = message_store.get_messages(role="human") + ai_messages = message_store.get_messages(role="ai") + qa_messages = message_store.get_messages(role="qa") + + assert len(human_messages) == 1 + assert len(ai_messages) == 1 + assert len(qa_messages) == 1 + assert human_messages[0]["content"] == "Hello" + assert ai_messages[0]["content"] == "Hi there" + assert qa_messages[0]["user_input"] == "What is 1+1?" + assert qa_messages[0]["aql_query"] == "RETURN 1+1" + assert qa_messages[0]["result"] == "2" + + all_messages = message_store.get_messages() + assert len(all_messages) == 3 + assert all_messages[0]["content"] == "Hello" + assert all_messages[1]["content"] == "Hi there" + assert all_messages[2]["user_input"] == "What is 1+1?" + assert all_messages[2]["aql_query"] == "RETURN 1+1" + assert all_messages[2]["result"] == "2"