Merged

39 commits
20dc942
successfully implement exact/vector search
anyxling Jul 26, 2025
8a9309d
add unit tests
anyxling Jul 26, 2025
5cf9f76
improve error handling
anyxling Jul 26, 2025
aa0008b
add integration tests for query caching
anyxling Jul 26, 2025
5a34484
remove set_trace()
anyxling Jul 26, 2025
1c93b9a
fix AI message error
anyxling Jul 26, 2025
8e9a7c1
fix integration tests to pass lint tests
anyxling Jul 27, 2025
749b62b
fix aql_gen_count error in unit tests and pass lint tests
anyxling Jul 27, 2025
69c3e5a
fix lint errors
anyxling Jul 27, 2025
8ef18a0
add langchain-openai
anyxling Jul 27, 2025
572b103
update poetry.lock
anyxling Jul 27, 2025
8546a75
upgrade ruff
anyxling Jul 27, 2025
8f620c6
add documentation
anyxling Jul 28, 2025
9406956
remove cast and reformat
anyxling Jul 28, 2025
8e00777
change to Embeddings and remove unnecessary args
anyxling Jul 29, 2025
c9c0fce
remove langchain-openai and update ruff version
anyxling Jul 29, 2025
b7c0a81
simplify integration tests
anyxling Jul 29, 2025
738e307
add score to the output
anyxling Jul 29, 2025
bf0c2fd
change to insert_many() and invoke()
anyxling Jul 29, 2025
d2be9b3
change to invoke() and revert result changes
anyxling Jul 29, 2025
2821eaa
refactor _call
anyxling Jul 30, 2025
7657e16
add cache_query()
anyxling Jul 30, 2025
7fa4a7b
Merge branch 'main' into monika-querycache
anyxling Jul 30, 2025
58c1674
customize clear_query_cache
anyxling Jul 30, 2025
7b9689c
handle edge cases for cache_query
anyxling Jul 31, 2025
70616d7
normalize user input text, query
anyxling Jul 31, 2025
0fa52e9
handle vector search in mock_execute
anyxling Jul 31, 2025
30bcb95
fix lint and format err
anyxling Jul 31, 2025
f699dbb
move: AQL query print
aMahanna Aug 4, 2025
3d94f09
misc: introduce `assert`, `self.graph._hash`, minor cleanup
aMahanna Aug 4, 2025
2ccc061
replace with hashed key and remove _format_aql
anyxling Aug 5, 2025
b728fc5
update integration test
anyxling Aug 5, 2025
d926aaa
move back aql_execution_func & params
anyxling Aug 5, 2025
45c2531
add tests for new funcs
anyxling Aug 5, 2025
7fb3895
rename __get_cached_query and fix lint err
anyxling Aug 5, 2025
b5f0701
remove redundant test
anyxling Aug 6, 2025
098b5bf
minor cleanup
aMahanna Aug 7, 2025
a04c336
cleanup PT2
aMahanna Aug 7, 2025
efb2b16
fix: `db`
aMahanna Aug 7, 2025
Binary file modified libs/arangodb/.coverage
239 changes: 223 additions & 16 deletions libs/arangodb/langchain_arangodb/chains/graph_qa/arangodb.py
@@ -8,6 +8,7 @@
from arango import AQLQueryExecuteError, AQLQueryExplainError
from langchain.chains.base import Chain
from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.embeddings import Embeddings
from langchain_core.language_models import BaseLanguageModel
from langchain_core.messages import AIMessage
from langchain_core.prompts import BasePromptTemplate
@@ -19,7 +20,7 @@
AQL_GENERATION_PROMPT,
AQL_QA_PROMPT,
)
from langchain_arangodb.graphs.graph_store import GraphStore
from langchain_arangodb.graphs.arangodb_graph import ArangoGraph

AQL_WRITE_OPERATIONS: List[str] = [
"INSERT",
@@ -45,7 +46,9 @@ class ArangoGraphQAChain(Chain):
See https://python.langchain.com/docs/security for more information.
"""

graph: GraphStore = Field(exclude=True)
graph: ArangoGraph = Field(exclude=True)
embedding: Optional[Embeddings] = Field(default=None, exclude=True)
query_cache_collection_name: str = Field(default="Queries")
aql_generation_chain: Runnable[Dict[str, Any], Any]
aql_fix_chain: Runnable[Dict[str, Any], Any]
qa_chain: Runnable[Dict[str, Any], Any]
@@ -102,6 +105,8 @@ def __init__(self, **kwargs: Any) -> None:
"necessary precautions. "
"See https://python.langchain.com/docs/security for more information."
)
self._last_user_input: Optional[str] = None
self._last_aql_query: Optional[str] = None

@property
def input_keys(self) -> List[str]:
@@ -132,6 +137,11 @@ def from_llm(

:param llm: The language model to use.
:type llm: BaseLanguageModel
:param embedding: The embedding model to use.
:type embedding: Embeddings
:param query_cache_collection_name: The name of the collection
to use for the query cache.
:type query_cache_collection_name: str
:param qa_prompt: The prompt to use for the QA chain.
:type qa_prompt: BasePromptTemplate
:param aql_generation_prompt: The prompt to use for the AQL generation chain.
@@ -162,6 +172,170 @@ def from_llm(
**kwargs,
)

def _check_and_insert_query(self, text: str, aql: str) -> str:
"""
Check if a query is already in the cache and insert it if it's not.

:param text: The text of the query to check.
:type text: str
:param aql: The AQL query to store for this text.
:type aql: str
:return: A message indicating the result of the operation.
"""
text = text.strip().lower()
text_hash = self.graph._hash(text)
collection = self.graph.db.collection(self.query_cache_collection_name)

if collection.has(text_hash):
return f"This query is already in the cache: {text}"

if self.embedding is None:
raise ValueError("Cannot cache queries without an embedding model.")

query_embedding = self.embedding.embed_query(text)
collection.insert(
{
"_key": text_hash,
"text": text,
"embedding": query_embedding,
"aql": aql,
}
)

return f"Cached: {text}"

def cache_query(self, text: Optional[str] = None, aql: Optional[str] = None) -> str:
"""
Cache a query generated by the LLM only if it's not already stored.

:param text: The text of the query to cache.
:param aql: The AQL query to cache.
:return: A message indicating the result of the operation.
"""
if self.embedding is None:
raise ValueError("Cannot cache queries without an embedding model.")

if not self.graph.db.has_collection(self.query_cache_collection_name):
m = f"Collection {self.query_cache_collection_name} does not exist" # noqa: E501
raise ValueError(m)

if not text and aql:
raise ValueError("Text is required to cache a query")

if text and not aql:
raise ValueError("AQL is required to cache a query")

if not text and not aql:
if self._last_user_input is None or self._last_aql_query is None:
m = "No previous query to cache. Please provide **text** and **aql**."
raise ValueError(m)

# Fallback: cache the most recent query
return self._check_and_insert_query(
self._last_user_input,
self._last_aql_query,
)

if not isinstance(text, str) or not isinstance(aql, str):
raise ValueError("Text and AQL must be strings")

return self._check_and_insert_query(text, aql)

def clear_query_cache(self, text: Optional[str] = None) -> str:
"""
Clear the query cache.

:param text: The text of the query to delete from the cache. If None, all cached queries are removed.
:type text: Optional[str]
:return: A message indicating the result of the operation.
"""

if not self.graph.db.has_collection(self.query_cache_collection_name):
m = f"Collection {self.query_cache_collection_name} does not exist"
raise ValueError(m)

collection = self.graph.db.collection(self.query_cache_collection_name)

if text is None:
collection.truncate()
return "Cleared all queries from the cache"

text = text.strip().lower()
text_hash = self.graph._hash(text)

if collection.has(text_hash):
collection.delete(text_hash)
return f"Removed: {text}"

return f"Not found: {text}"

def _get_cached_query(
self, user_input: str, query_cache_similarity_threshold: float
) -> Optional[Tuple[str, str]]:
"""Get the cached query for the user input. Only used if embedding
is provided and **use_query_cache** is True.

:param user_input: The user input to search for in the cache.
:type user_input: str
:param query_cache_similarity_threshold: The minimum cosine similarity required to accept a vector-search match.
:type query_cache_similarity_threshold: float

:return: The cached AQL query and its similarity score, if found.
:rtype: Optional[Tuple[str, str]]
"""
if self.embedding is None:
raise ValueError("Cannot enable query cache without passing embedding")

if self.graph.db.collection(self.query_cache_collection_name).count() == 0:
return None

user_input = user_input.strip().lower()

# 1. Exact Search

query = f"""
FOR q IN {self.query_cache_collection_name}
FILTER q.text == @user_input
LIMIT 1
RETURN q.aql
"""

cursor = self.graph.db.aql.execute(
query,
bind_vars={"user_input": user_input},
)

result = list(cursor) # type: ignore

if result:
return result[0], "1.0"

# 2. Vector Search

embedding = self.embedding.embed_query(user_input)
query = """
FOR q IN @@col
LET score = COSINE_SIMILARITY(q.embedding, @embedding)
SORT score DESC
LIMIT 1
FILTER score > @score_threshold
RETURN {aql: q.aql, score: score}
"""

result = list(
self.graph.db.aql.execute(
query,
bind_vars={
"@col": self.query_cache_collection_name,
"embedding": embedding, # type: ignore
"score_threshold": query_cache_similarity_threshold, # type: ignore
},
)
)

if result:
return result[0]["aql"], str(round(result[0]["score"], 2))

return None

def _call(
self,
inputs: Dict[str, Any],
@@ -211,20 +385,46 @@ def _call(
"""
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
callbacks = _run_manager.get_child()
user_input = inputs[self.input_key]
user_input = inputs[self.input_key].strip().lower()
use_query_cache = inputs.get("use_query_cache", False)
query_cache_similarity_threshold = inputs.get(
"query_cache_similarity_threshold", 0.80
)

if use_query_cache and self.embedding is None:
raise ValueError("Cannot enable query cache without passing embedding")

######################
# Generate AQL Query #
# Check Query Cache #
######################

aql_generation_output = self.aql_generation_chain.invoke(
{
"adb_schema": self.graph.schema_yaml,
"aql_examples": self.aql_examples,
"user_input": user_input,
},
callbacks=callbacks,
)
cached_query, score = None, None
if use_query_cache:
if self.embedding is None:
m = "Embedding must be provided when using query cache"
raise ValueError(m)

if not self.graph.db.has_collection(self.query_cache_collection_name):
self.graph.db.create_collection(self.query_cache_collection_name)

cache_result = self._get_cached_query(
user_input, query_cache_similarity_threshold
)

if cache_result is not None:
cached_query, score = cache_result

if cached_query:
aql_generation_output = f"```aql{cached_query}```"
else:
aql_generation_output = self.aql_generation_chain.invoke(
{
"adb_schema": self.graph.schema_yaml,
"aql_examples": self.aql_examples,
"user_input": user_input,
},
callbacks=callbacks,
)

aql_query = ""
aql_error = ""
@@ -283,9 +483,14 @@ def _call(
"""
raise ValueError(error_msg)

_run_manager.on_text(
f"AQL Query ({aql_generation_attempt}):", verbose=self.verbose
)
query_message = f"AQL Query ({aql_generation_attempt})\n"
if cached_query:
score_string = score if score is not None else "1.0"
query_message = (
f"AQL Query (used cached query, score: {score_string})\n" # noqa: E501
)

_run_manager.on_text(query_message, verbose=self.verbose)
_run_manager.on_text(
aql_query, color="green", end="\n", verbose=self.verbose
)
@@ -300,7 +505,6 @@ def _call(
"list_limit": self.output_list_limit,
"string_limit": self.output_string_limit,
}

aql_result = aql_execution_func(aql_query, params)
except (AQLQueryExecuteError, AQLQueryExplainError) as e:
aql_error = str(e.error_message)
@@ -368,6 +572,9 @@ def _call(
if self.return_aql_result:
results["aql_result"] = aql_result

self._last_user_input = user_input
self._last_aql_query = aql_query

return results
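
From the caller's side, enabling the cache looks roughly like the sketch below. This is a minimal, hedged example: the connection details, the OpenAI models, and the default input key `query` are assumptions for illustration, not part of this diff.

```python
from arango import ArangoClient
from langchain_openai import ChatOpenAI, OpenAIEmbeddings

from langchain_arangodb.chains.graph_qa.arangodb import ArangoGraphQAChain
from langchain_arangodb.graphs.arangodb_graph import ArangoGraph

db = ArangoClient().db("_system", username="root", password="passwd")  # placeholder credentials
graph = ArangoGraph(db)

chain = ArangoGraphQAChain.from_llm(
    llm=ChatOpenAI(model="gpt-4o"),
    graph=graph,
    embedding=OpenAIEmbeddings(),            # required for the query cache
    query_cache_collection_name="Queries",   # default collection name
    allow_dangerous_requests=True,
)

# First call: the LLM generates the AQL; the cache collection is created
# on demand because use_query_cache=True.
chain.invoke({
    "query": "Who directed Inception?",
    "use_query_cache": True,
    "query_cache_similarity_threshold": 0.80,  # default used in _call()
})
chain.cache_query()  # persist the question/AQL pair that was just produced

# A later, similar question is answered via the cache: exact text match
# first, then COSINE_SIMILARITY over the stored embeddings.
chain.invoke({"query": "who directed inception", "use_query_cache": True})
```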

def _is_read_only_query(self, aql_query: str) -> Tuple[bool, Optional[str]]:
4 changes: 2 additions & 2 deletions libs/arangodb/langchain_arangodb/graphs/arangodb_graph.py
@@ -424,7 +424,7 @@ def query(
top_k = params.pop("top_k", None)
list_limit = params.pop("list_limit", 32)
string_limit = params.pop("string_limit", 256)
cursor = self.__db.aql.execute(query, **params)
cursor = self.db.aql.execute(query, **params)

results = []

@@ -454,7 +454,7 @@ def explain(self, query: str, params: dict = {}) -> List[Dict[str, Any]]:
:raises ArangoServerError: If the ArangoDB server cannot be reached.
:raises ArangoCollectionError: If the collection cannot be created.
"""
return self.__db.aql.explain(query) # type: ignore
return self.db.aql.explain(query) # type: ignore

def add_graph_documents(
self,
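
The `arangodb_graph.py` change routes `query()` and `explain()` through the public `db` property instead of the name-mangled `self.__db`, i.e. the same handle the chain's cache helpers already use. A small hedged sketch (setup and collection name are illustrative defaults):

```python
graph = ArangoGraph(db)

# The chain's cache helpers rely on this public accessor, for example:
if not graph.db.has_collection("Queries"):
    graph.db.create_collection("Queries")

# query() and explain() now execute against the same public handle.
results = graph.query("RETURN 1")
```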