Merged
Changes from all commits
46 commits
59b6cfd
successfully implement exact/vector search
anyxling Jul 26, 2025
851b7bc
add unit tests
anyxling Jul 26, 2025
832f852
improve error handling
anyxling Jul 26, 2025
0a6c5ba
add integration tests for query caching
anyxling Jul 26, 2025
2942063
remove set_trace()
anyxling Jul 26, 2025
5178c0b
fix AI message error
anyxling Jul 26, 2025
128ba8f
fix integration tests to pass lint tests
anyxling Jul 27, 2025
d77e5c4
fix aql_gen_count error in unit tests and pass lint tests
anyxling Jul 27, 2025
e0bb02d
fix lint errors
anyxling Jul 27, 2025
1d4d8e8
add langchain-openai
anyxling Jul 27, 2025
f63f19a
update poetry.lock
anyxling Jul 27, 2025
175f233
add documentation
anyxling Jul 28, 2025
e8d30ea
remove cast and reformat
anyxling Jul 28, 2025
575de90
change to Embeddings and remove unnecessary args
anyxling Jul 29, 2025
e8e0380
remove langchain-openai and update ruff version
anyxling Jul 29, 2025
80109ab
simplify integration tests
anyxling Jul 29, 2025
38be0c0
add score to the output
anyxling Jul 29, 2025
7906a1c
change to insert_many() and invoke()
anyxling Jul 29, 2025
78cca90
change to invoke() and revert result changes
anyxling Jul 29, 2025
3860945
refactor _call
anyxling Jul 30, 2025
a163ebd
customize clear_query_cache
anyxling Jul 30, 2025
4490482
handle edge cases for cache_query
anyxling Jul 31, 2025
630251f
normalize user input text, query
anyxling Jul 31, 2025
677ff38
move: AQL query print
aMahanna Aug 4, 2025
087a1fd
misc: introduce `assert`, `self.graph._hash`, minor cleanup
aMahanna Aug 4, 2025
2eeb919
replace with hashed key and remove _format_aql
anyxling Aug 5, 2025
b6e7e26
add chat history
anyxling Aug 5, 2025
1f2d16a
emphasize chat history in prompt
anyxling Aug 5, 2025
f8601ab
change chat history format & add summary to the output
anyxling Aug 5, 2025
9d46e76
handle edge cases & add documentation
anyxling Aug 5, 2025
6b13a7d
fix inbound/outbound, sort, pronounce issues
anyxling Aug 5, 2025
d0a1a94
add integration test for chat history
anyxling Aug 5, 2025
1eab05a
move back aql_execution_func & params
anyxling Aug 5, 2025
83eb040
rename __get_cached_query and fix lint err
anyxling Aug 5, 2025
be709f7
add unit test for chat history
anyxling Aug 6, 2025
7fa7f4a
format & lint
anyxling Aug 6, 2025
f17db2e
sync with pre-rebase state
anyxling Aug 11, 2025
a52ba42
new: parameter override at runtime
aMahanna Aug 12, 2025
c287077
instantiate w/ ternary operator
anyxling Aug 12, 2025
f21f517
remove chat history from aql2text
anyxling Aug 12, 2025
b0030f3
simplify printing of summary
anyxling Aug 12, 2025
d7841eb
remove sort in prompt
anyxling Aug 12, 2025
c89df89
Merge branch 'monika-querycache-chat-history' of https://github.com/a…
anyxling Aug 12, 2025
a5778ea
disable history in prompt
anyxling Aug 13, 2025
f7ac176
add aql query to llm input
anyxling Aug 13, 2025
cb934e0
remove type ignore
anyxling Aug 13, 2025
65 changes: 62 additions & 3 deletions libs/arangodb/langchain_arangodb/chains/graph_qa/arangodb.py
@@ -10,7 +10,7 @@
from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.embeddings import Embeddings
from langchain_core.language_models import BaseLanguageModel
from langchain_core.messages import AIMessage
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.prompts import BasePromptTemplate
from langchain_core.runnables import Runnable
from pydantic import Field
@@ -20,6 +20,7 @@
AQL_GENERATION_PROMPT,
AQL_QA_PROMPT,
)
from langchain_arangodb.chat_message_histories.arangodb import ArangoChatMessageHistory
from langchain_arangodb.graphs.arangodb_graph import ArangoGraph

AQL_WRITE_OPERATIONS: List[str] = [
@@ -54,6 +55,11 @@ class ArangoGraphQAChain(Chain):
qa_chain: Runnable[Dict[str, Any], Any]
input_key: str = "query" #: :meta private:
output_key: str = "result" #: :meta private:
use_query_cache: bool = Field(default=False)
query_cache_similarity_threshold: float = Field(default=0.80)
include_history: bool = Field(default=False)
max_history_messages: int = Field(default=10)
chat_history_store: Optional[ArangoChatMessageHistory] = Field(default=None)

top_k: int = 10
"""Number of results to return from the query"""
@@ -142,6 +148,13 @@ def from_llm(
:param query_cache_collection_name: The name of the collection
to use for the query cache.
:type query_cache_collection_name: str
:param include_history: Whether to include the chat history in the prompt.
:type include_history: bool
:param max_history_messages: The maximum number of messages to
include in the chat history.
:type max_history_messages: int
:param chat_history_store: The chat history store to use.
:type chat_history_store: ArangoChatMessageHistory
:param qa_prompt: The prompt to use for the QA chain.
:type qa_prompt: BasePromptTemplate
:param aql_generation_prompt: The prompt to use for the AQL generation chain.
@@ -386,14 +399,40 @@ def _call(
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
callbacks = _run_manager.get_child()
user_input = inputs[self.input_key].strip().lower()
use_query_cache = inputs.get("use_query_cache", False)

# Query Cache Parameters (can be overridden by inputs at runtime)
use_query_cache = inputs.get("use_query_cache", self.use_query_cache)
query_cache_similarity_threshold = inputs.get(
"query_cache_similarity_threshold", 0.80
"query_cache_similarity_threshold", self.query_cache_similarity_threshold
)

# Chat History Parameters (can be overridden by inputs at runtime)
include_history = inputs.get("include_history", self.include_history)
max_history_messages = inputs.get(
"max_history_messages", self.max_history_messages
)

if use_query_cache and self.embedding is None:
raise ValueError("Cannot enable query cache without passing embedding")

# ######################
# # Get Chat History #
# ######################

if include_history and self.chat_history_store is None:
raise ValueError(
"Chat message history is required if include_history is True"
)

if max_history_messages <= 0:
raise ValueError("max_history_messages must be greater than 0")

chat_history = []
if include_history and self.chat_history_store is not None:
for msg in self.chat_history_store.messages[-max_history_messages:]:
cls = HumanMessage if msg.type == "human" else AIMessage
chat_history.append(cls(content=msg.content))

######################
# Check Query Cache #
######################
@@ -422,6 +461,7 @@ def _call(
"adb_schema": self.graph.schema_yaml,
"aql_examples": self.aql_examples,
"user_input": user_input,
"chat_history": chat_history,
},
callbacks=callbacks,
)
@@ -564,6 +604,16 @@ def _call(
callbacks=callbacks,
)

# Add summary
text = "Summary:"
_run_manager.on_text(text, end="\n", verbose=self.verbose)
_run_manager.on_text(
str(result.content) if isinstance(result, AIMessage) else result,
color="green",
end="\n",
verbose=self.verbose,
)

results: Dict[str, Any] = {self.output_key: result}

if self.return_aql_query:
@@ -575,6 +625,15 @@
self._last_user_input = user_input
self._last_aql_query = aql_query

########################
# Store Chat History #
########################

if self.chat_history_store:
self.chat_history_store.add_user_message(user_input)
self.chat_history_store.add_ai_message(aql_query)
self.chat_history_store.add_ai_message(result)

return results

def _is_read_only_query(self, aql_query: str) -> Tuple[bool, Optional[str]]:
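Taken together, the additions above introduce chat-history and query-cache options that can be set when the chain is built and overridden per call. The following is a minimal usage sketch assembled only from the signatures visible in this diff: the database credentials, session id, collection name, and the canned stand-in LLM are hypothetical placeholders, and the per-call override keys are assumed to be accepted by `invoke()` in the way the `_call` changes read them via `inputs.get(...)`.

```python
from arango import ArangoClient
from langchain_core.messages import AIMessage
from langchain_core.runnables import RunnableLambda

from langchain_arangodb.chains.graph_qa.arangodb import ArangoGraphQAChain
from langchain_arangodb.chat_message_histories.arangodb import ArangoChatMessageHistory
from langchain_arangodb.graphs.arangodb_graph import ArangoGraph

# Hypothetical connection details -- point these at your own ArangoDB instance.
db = ArangoClient().db("_system", username="root", password="test")
graph = ArangoGraph(db)

# Chat history is persisted in an ArangoDB collection, keyed by session.
history = ArangoChatMessageHistory(
    session_id="demo-session",        # hypothetical session id
    collection_name="chat_sessions",  # hypothetical collection name
    db=db,
)

# Stand-in LLM that always returns the same read-only AQL; a real chat model
# would normally go here. Wrapping a function in RunnableLambda mirrors the
# integration test above.
def fake_llm(prompt):  # type: ignore[no-untyped-def]
    return AIMessage(content="```aql\nRETURN 1\n```")

chain = ArangoGraphQAChain.from_llm(
    llm=RunnableLambda(fake_llm),  # type: ignore[arg-type]
    graph=graph,
    allow_dangerous_requests=True,
    include_history=True,          # new: feed prior turns into AQL generation
    max_history_messages=5,        # new: cap how many stored messages are replayed
    chat_history_store=history,    # new: where each turn is persisted
)

# Construction-time defaults can be overridden on a per-call basis.
result = chain.invoke(
    {
        "query": "What is the first movie?",
        "include_history": True,
        "max_history_messages": 3,
    }
)
print(result["result"])
```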
27 changes: 24 additions & 3 deletions libs/arangodb/langchain_arangodb/chains/graph_qa/prompts.py
@@ -3,7 +3,7 @@

AQL_GENERATION_TEMPLATE = """Task: Generate an ArangoDB Query Language (AQL) query from a User Input.

You are an ArangoDB Query Language (AQL) expert responsible for translating a `User Input` into an ArangoDB Query Language (AQL) query.
You are an ArangoDB Query Language (AQL) expert responsible for translating a `User Input` into an ArangoDB Query Language (AQL) query. You may also be given a `Chat History` to help you create the `AQL Query`.

You are given an `ArangoDB Schema`. It is a YAML Spec containing:
1. `Graph Schema`: Lists all Graphs within the ArangoDB Database Instance, along with their Edge Relationships.
@@ -13,8 +13,18 @@

You may also be given a set of `AQL Query Examples` to help you create the `AQL Query`. If provided, the `AQL Query Examples` should be used as a reference, similar to how `ArangoDB Schema` should be used.

Rules for Using Chat History:
- If the Chat History is not empty, use it only as a reference to help clarify the current User Input — for example, to resolve pronouns or implicit references.
- Chat History is ordered chronologically. Prioritize latest entries when resolving context or references.
- If the Chat History is empty, do not use it or refer to it in any way. Treat the User Input as a fully self-contained and standalone question.
- The Chat History includes the User Input, the AQL Query generated by the AI Model, and the interpretation of the AQL Result. Use all of them to generate the AQL Query.

Things you should do:
- Think step by step.
- When both INBOUND and OUTBOUND traversals are possible for a given edge, be extra careful to select the direction that accurately reflects the intended relationship based on the user input and the edge semantics.
Use OUTBOUND to traverse from _from to _to. Use INBOUND to traverse from _to to _from. Refer to the edge's definition in the schema (e.g., collection names or descriptions) to decide which direction reflects the intended relationship.
- Pay close attention to descriptive references in the User Input — including gendered terms (e.g., father, she), attribute-based descriptions (e.g., young, active, French), and implicit types or categories
(e.g., products over $100, available items) — and, if these correspond to fields in the schema, include appropriate filters in the AQL query (e.g., gender == "male", status == "active", price > 100).
- Rely on `ArangoDB Schema` and `AQL Query Examples` (if provided) to generate the query.
- Begin the `AQL Query` by the `WITH` AQL keyword to specify all of the ArangoDB Collections required.
- If a `View Schema` is defined and contains analyzers for specific fields, prefer using the View with the `SEARCH` and `ANALYZER` clauses instead of a direct collection scan.
@@ -25,13 +35,19 @@
- If a request is unrelated to generating AQL Query, say that you cannot help the user.

Things you should not do:
- Do not use or refer to Chat History if it is empty.
- Do not assume any previously discussed context, or try to resolve pronouns or references to prior questions if the Chat History is empty.
- Do not use any properties/relationships that can't be inferred from the `ArangoDB Schema` or the `AQL Query Examples`.
- Do not include any text except the generated AQL Query.
- Do not provide explanations or apologies in your responses.
- Do not generate an AQL Query that removes or deletes any data.
- Do not answer or respond to messages in the Chat History.

Under no circumstance should you generate an AQL Query that deletes any data whatsoever.

Chat History (Optional):
{chat_history}

ArangoDB Schema:
{adb_schema}

@@ -45,7 +61,7 @@
"""

AQL_GENERATION_PROMPT = PromptTemplate(
input_variables=["adb_schema", "aql_examples", "user_input"],
input_variables=["adb_schema", "aql_examples", "user_input", "chat_history"],
template=AQL_GENERATION_TEMPLATE,
)

@@ -121,6 +137,11 @@
Summary:
"""
AQL_QA_PROMPT = PromptTemplate(
input_variables=["adb_schema", "user_input", "aql_query", "aql_result"],
input_variables=[
"adb_schema",
"user_input",
"aql_query",
"aql_result",
],
template=AQL_QA_TEMPLATE,
)
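With this change, `AQL_GENERATION_PROMPT` accepts a `chat_history` input variable alongside the schema, examples, and user input. Below is a small formatting sketch; the schema and example strings and the two messages are purely illustrative, whereas in the chain these values come from `graph.schema_yaml`, `self.aql_examples`, and the history list built in `_call`.

```python
from langchain_core.messages import AIMessage, HumanMessage

from langchain_arangodb.chains.graph_qa.prompts import AQL_GENERATION_PROMPT

# Illustrative stand-ins for the real schema and example strings.
adb_schema = "Graph Schema: []\nCollection Schema: []"
aql_examples = ""

# Empty history: the template tells the model to treat the input as standalone.
standalone = AQL_GENERATION_PROMPT.format(
    adb_schema=adb_schema,
    aql_examples=aql_examples,
    user_input="What is the first movie?",
    chat_history=[],
)

# Populated history: prior turns are available for resolving pronouns such as "it".
follow_up = AQL_GENERATION_PROMPT.format(
    adb_schema=adb_schema,
    aql_examples=aql_examples,
    user_input="When was it released?",
    chat_history=[
        HumanMessage(content="What is the first movie?"),
        AIMessage(content="The first movie is Inception."),
    ],
)
print(follow_up)
```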
Original file line number Diff line number Diff line change
@@ -10,6 +10,7 @@
from langchain_core.runnables import RunnableLambda

from langchain_arangodb.chains.graph_qa.arangodb import ArangoGraphQAChain
from langchain_arangodb.chat_message_histories.arangodb import ArangoChatMessageHistory
from langchain_arangodb.graphs.arangodb_graph import ArangoGraph
from tests.llms.fake_llm import FakeLLM

@@ -1062,3 +1063,68 @@ def test_query_cache(db: StandardDatabase) -> None:
)()
query = chain._get_cached_query("gibberish", 0.99)
assert query is None


@pytest.mark.usefixtures("clear_arangodb_database")
def test_chat_history(db: StandardDatabase) -> None:
"""
Test chat history that enables context-aware query generation.
"""
# 1. Create required collections
graph = ArangoGraph(db)
db.create_collection("Movies")
db.collection("Movies").insert_many(
[
{"_key": "matrix", "title": "The Matrix", "year": 1999},
{"_key": "inception", "title": "Inception", "year": 2010},
]
)
graph.refresh_schema()

# 2. Create chat history store
history = ArangoChatMessageHistory(
session_id="test",
collection_name="test_chat_sessions",
db=db,
)
history.clear()

# 3. Dummy LLM: simulate coreference to "The Matrix"
def dummy_llm(prompt): # type: ignore
if "when was it released" in str(prompt).lower(): # type: ignore
return AIMessage(
content="""```aql
WITH Movies
FOR m IN Movies
FILTER m.title == "The Matrix"
RETURN m.year
```"""
)
return AIMessage(
content="""```aql
WITH Movies
FOR m IN Movies
SORT m._key ASC
LIMIT 1
RETURN m.title
```"""
)

dummy_chain = ArangoGraphQAChain.from_llm(
llm=RunnableLambda(dummy_llm), # type: ignore
graph=graph,
allow_dangerous_requests=True,
include_history=True,
max_history_messages=5,
chat_history_store=history,
return_aql_result=True,
return_aql_query=True,
)

# 4. Ask initial question
result1 = dummy_chain.invoke({"query": "What is the first movie?"})
assert "Inception" in result1["aql_result"]

# 5. Ask follow-up question using pronoun "it"
result2 = dummy_chain.invoke({"query": "When was it released?"})
assert 1999 in result2["aql_result"]
65 changes: 65 additions & 0 deletions libs/arangodb/tests/unit_tests/chains/test_graph_qa.py
@@ -11,6 +11,7 @@
from langchain_core.runnables import Runnable, RunnableLambda

from langchain_arangodb.chains.graph_qa.arangodb import ArangoGraphQAChain
from langchain_arangodb.chat_message_histories.arangodb import ArangoChatMessageHistory
from langchain_arangodb.graphs.arangodb_graph import ArangoGraph
from tests.llms.fake_llm import FakeLLM

@@ -644,3 +645,67 @@ def mock_execute(query, bind_vars): # type: ignore
# 4. Test with query cache disabled
result4 = chain.invoke({"query": "What is the name of the first movie?"})
assert result4["result"] == "```FOR m IN Movies LIMIT 1 RETURN m```"

def test_chat_history(
self, fake_graph_store: FakeGraphStore, mock_chains: Dict[str, Runnable]
) -> None:
"""test _call with chat history"""

chat_history_store = Mock(spec=ArangoChatMessageHistory)

# Add fake message history (as objects, not dicts)
chat_history_store.messages = [
Mock(type="human", content="What is 1+1?"),
Mock(type="ai", content="2"),
Mock(type="human", content="What is 2+2?"),
Mock(type="ai", content="4"),
]

# Mock LLM chains
mock_chains[ # type: ignore
"aql_generation_chain"
].invoke.return_value = "```aql\nFOR m IN Movies RETURN m\n```" # noqa: E501
mock_chains["qa_chain"].invoke.return_value = AIMessage( # type: ignore
content="Here are the movies."
) # noqa: E501

# Build the chain
chain = ArangoGraphQAChain(
graph=fake_graph_store,
aql_generation_chain=mock_chains["aql_generation_chain"],
aql_fix_chain=mock_chains["aql_fix_chain"],
qa_chain=mock_chains["qa_chain"],
allow_dangerous_requests=True,
include_history=True,
chat_history_store=chat_history_store,
max_history_messages=10,
return_aql_result=True,
)

# Run the call
result = chain.invoke({"query": "List all movies"})

# LLM received the latest 2 pairs (4 messages)
llm_input = mock_chains["aql_generation_chain"].invoke.call_args[0][0] # type: ignore
chat_history = llm_input["chat_history"]
assert len(chat_history) == 4

# result has expected fields
assert result["result"].content == "Here are the movies."
assert result["aql_result"][0]["title"] == "Inception"

# Error: chat history enabled but store is missing
chain.chat_history_store = None
with pytest.raises(
ValueError,
match="Chat message history is required if include_history is True",
):
chain.invoke({"query": "List again"})

# Error: invalid max_history_messages
chain.chat_history_store = chat_history_store
chain.max_history_messages = 0
with pytest.raises(
ValueError, match="max_history_messages must be greater than 0"
):
chain.invoke({"query": "List again"})