Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 75 additions & 45 deletions libs/arangodb/langchain_arangodb/vectorstores/arangodb_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,22 +230,31 @@ def create_keyword_index(self) -> None:
if self.retrieve_keyword_index():
return

view_properties = {
"links": {
self.collection_name: {
"analyzers": [self.keyword_analyzer],
"fields": {self.text_field: {"analyzers": [self.keyword_analyzer]}},
}
collection = self.db.collection(self.collection_name)
collection.add_index(
{
"type": "inverted",
"name": self.keyword_index_name,
"fields": [
{"name": self.text_field, "analyzer": self.keyword_analyzer}
],
}
)
view_properties = {
"indexes": [
{"collection": self.collection_name, "index": self.keyword_index_name}
]
}

self.db.create_view(self.keyword_index_name, "arangosearch", view_properties)
self.db.create_view(self.keyword_index_name, "search-alias", view_properties)

def delete_keyword_index(self) -> None:
"""Delete the keyword index from the collection."""
view = self.retrieve_keyword_index()
if view:
self.db.delete_view(self.keyword_index_name)
self.db.collection(self.collection_name).delete_index(
self.keyword_index_name, ignore_missing=True
)

def add_embeddings(
self,
Expand Down Expand Up @@ -338,6 +347,7 @@ def similarity_search(
vector_weight: float = 1.0,
keyword_weight: float = 1.0,
keyword_search_clause: str = "",
metadata_clause: str = "",
**kwargs: Any,
) -> List[Document]:
"""Search for similar documents using vector similarity or hybrid search.
Expand Down Expand Up @@ -367,6 +377,9 @@ def similarity_search(
Only used when search_type is "hybrid". Defaults to 1.0.
keyword_search_clause: Optional AQL filter clause to apply Full Text Search.
If empty, a default search clause will be used.
metadata_clause: Optional AQL clause to return additional metadata once
the top k results are retrieved. If specified, the metadata will be
added to the Document.metadata field.

Returns:
List of Document objects most similar to the query.
Expand All @@ -381,6 +394,7 @@ def similarity_search(
return_fields=return_fields,
use_approx=use_approx,
filter_clause=filter_clause,
metadata_clause=metadata_clause,
)

else:
Expand All @@ -394,6 +408,7 @@ def similarity_search(
vector_weight=vector_weight,
keyword_weight=keyword_weight,
keyword_search_clause=keyword_search_clause,
metadata_clause=metadata_clause,
)

def similarity_search_with_score(
Expand All @@ -408,6 +423,7 @@ def similarity_search_with_score(
vector_weight: float = 1.0,
keyword_weight: float = 1.0,
keyword_search_clause: str = "",
metadata_clause: str = "",
) -> List[tuple[Document, float]]:
"""Search for similar documents and return their similarity scores.

Expand Down Expand Up @@ -449,6 +465,7 @@ def similarity_search_with_score(
return_fields=return_fields,
use_approx=use_approx,
filter_clause=filter_clause,
metadata_clause=metadata_clause,
)

else:
Expand All @@ -462,6 +479,7 @@ def similarity_search_with_score(
vector_weight=vector_weight,
keyword_weight=keyword_weight,
keyword_search_clause=keyword_search_clause,
metadata_clause=metadata_clause,
)

def similarity_search_by_vector(
Expand All @@ -471,6 +489,7 @@ def similarity_search_by_vector(
return_fields: set[str] = set(),
use_approx: bool = True,
filter_clause: str = "",
metadata_clause: str = "",
**kwargs: Any,
) -> List[Document]:
"""Return docs most similar to embedding vector.
Expand All @@ -484,6 +503,9 @@ def similarity_search_by_vector(
use_approx: Whether to use approximate vector search via ANN.
Defaults to True. If False, exact vector search will be used.
filter_clause: Filter clause to apply to the query.
metadata_clause: Optional AQL clause to return additional metadata once
the top k results are retrieved. If specified, the metadata will be
added to the Document.metadata field.

Returns:
List of Documents most similar to the query vector.
Expand All @@ -494,6 +516,7 @@ def similarity_search_by_vector(
return_fields=return_fields,
use_approx=use_approx,
filter_clause=filter_clause,
metadata_clause=metadata_clause,
)

return [doc for doc, _ in results]
Expand All @@ -509,6 +532,7 @@ def similarity_search_by_vector_and_keyword(
vector_weight: float = 1.0,
keyword_weight: float = 1.0,
keyword_search_clause: str = "",
metadata_clause: str = "",
) -> List[Document]:
results = self.similarity_search_by_vector_and_keyword_with_score(
query=query,
Expand All @@ -520,6 +544,7 @@ def similarity_search_by_vector_and_keyword(
vector_weight=vector_weight,
keyword_weight=keyword_weight,
keyword_search_clause=keyword_search_clause,
metadata_clause=metadata_clause,
)

return [doc for doc, _ in results]
Expand All @@ -531,6 +556,7 @@ def similarity_search_by_vector_with_score(
return_fields: set[str] = set(),
use_approx: bool = True,
filter_clause: str = "",
metadata_clause: str = "",
) -> List[tuple[Document, float]]:
"""Return docs most similar to embedding vector.

Expand All @@ -543,6 +569,9 @@ def similarity_search_by_vector_with_score(
use_approx: Whether to use approximate vector search via ANN.
Defaults to True. If False, exact vector search will be used.
filter_clause: Filter clause to apply to the query.
metadata_clause: Optional AQL clause to return additional metadata once
the top k results are retrieved. If specified, the metadata will be
added to the Document.metadata field.
**kwargs: Additional keyword arguments passed to the query execution.

Returns:
Expand All @@ -554,6 +583,7 @@ def similarity_search_by_vector_with_score(
return_fields=return_fields,
use_approx=use_approx,
filter_clause=filter_clause,
metadata_clause=metadata_clause,
)

cursor = self.db.aql.execute(aql_query, bind_vars=bind_vars, stream=True)
Expand All @@ -573,6 +603,7 @@ def similarity_search_by_vector_and_keyword_with_score(
vector_weight: float = 1.0,
keyword_weight: float = 1.0,
keyword_search_clause: str = "",
metadata_clause: str = "",
) -> List[tuple[Document, float]]:
"""Run similarity search with ArangoDB.

Expand All @@ -591,6 +622,9 @@ def similarity_search_by_vector_and_keyword_with_score(
Only used when search_type is "hybrid". Defaults to 1.0.
keyword_search_clause: Optional AQL filter clause to apply Full Text Search.
If empty, a default search clause will be used.
metadata_clause: Optional AQL clause to return additional metadata once
the top k results are retrieved. If specified, the metadata will be
added to the Document.metadata field.

Returns:
List of Documents most similar to the query.
Expand All @@ -606,6 +640,7 @@ def similarity_search_by_vector_and_keyword_with_score(
vector_weight=vector_weight,
keyword_weight=keyword_weight,
keyword_search_clause=keyword_search_clause,
metadata_clause=metadata_clause,
)

cursor = self.db.aql.execute(aql_query, bind_vars=bind_vars, stream=True)
Expand Down Expand Up @@ -980,10 +1015,16 @@ def _process_search_query(self, cursor: Cursor) -> List[tuple[Document, float]]:

while not cursor.empty():
for result in cursor:
data, score = result["data"], result["score"]
data, score, metadata = (
result["data"],
result["score"],
result["metadata"],
)
_key = data.pop("_key")
page_content = data.pop(self.text_field)
doc = Document(page_content=page_content, id=_key, metadata=data)
doc = Document(
page_content=page_content, id=_key, metadata={**data, **metadata}
)

results.append((doc, score))

Expand All @@ -999,6 +1040,7 @@ def _build_vector_search_query(
return_fields: set[str],
use_approx: bool,
filter_clause: str,
metadata_clause: str,
) -> Tuple[str, dict[str, Any]]:
if self._distance_strategy == DistanceStrategy.COSINE:
score_func = "APPROX_NEAR_COSINE" if use_approx else "COSINE_SIMILARITY"
Expand Down Expand Up @@ -1028,7 +1070,8 @@ def _build_vector_search_query(
LIMIT {k}
{filter_clause if use_approx else ""}
LET data = KEEP(doc, {return_fields_list})
RETURN {{data, score}}
LET metadata = {f'({metadata_clause})' if metadata_clause else '{}'}
RETURN {{data, score, metadata}}
"""

bind_vars = {
Expand All @@ -1049,6 +1092,7 @@ def _build_hybrid_search_query(
vector_weight: float,
keyword_weight: float,
keyword_search_clause: str,
metadata_clause: str,
) -> Tuple[str, dict[str, Any]]:
"""Build the hybrid search query using RRF."""

Expand Down Expand Up @@ -1091,7 +1135,10 @@ def _build_hybrid_search_query(
SORT score {sort_order}
LIMIT {k}
{filter_clause if use_approx else ""}
RETURN {{ doc, score }}
WINDOW {{ preceding: "unbounded", following: 0 }}
AGGREGATE rank = COUNT(1)
LET rrf_score = {vector_weight} / ({self.rrf_constant} + rank)
RETURN {{ key: doc._key, score: rrf_score }}
)

LET keyword_results = (
Expand All @@ -1101,39 +1148,24 @@ def _build_hybrid_search_query(
LET score = BM25(doc)
SORT score DESC
LIMIT {k}
RETURN {{ doc, score }}
)

LET rrf_vector = (
FOR i IN RANGE(0, LENGTH(vector_results) - 1)
LET doc = vector_results[i].doc
FILTER doc != null
RETURN {{
doc,
score: {vector_weight} / (@rrf_constant + i + 1)
}}
)

LET rrf_keyword = (
FOR i IN RANGE(0, LENGTH(keyword_results) - 1)
LET doc = keyword_results[i].doc
FILTER doc != null
RETURN {{
doc,
score: {keyword_weight} / (@rrf_constant + i + 1)
}}
WINDOW {{ preceding: "unbounded", following: 0 }}
AGGREGATE rank = COUNT(1)
LET rrf_score = {keyword_weight} / ({self.rrf_constant} + rank)
RETURN {{ key: doc._key, score: rrf_score }}
)

FOR result IN APPEND(rrf_vector, rrf_keyword)
COLLECT doc_key = result.doc._key INTO group
LET rrf_score = SUM(group[*].result.score)
LET doc = group[0].result.doc
SORT rrf_score DESC
LIMIT @rrf_search_limit
RETURN {{
data: KEEP(doc, {return_fields_list}),
score: rrf_score
}}
FOR result IN APPEND(vector_results, keyword_results)
COLLECT key = result.key AGGREGATE score = SUM(result.score)
SORT score DESC
LIMIT {self.rrf_search_limit}
LET data = FIRST(
FOR doc IN @@collection
FILTER doc._key == key
LIMIT 1
RETURN KEEP(doc, {return_fields_list})
)
LET metadata = {f'({metadata_clause})' if metadata_clause else '{}'}
RETURN {{ data, score, metadata }}
"""

bind_vars = {
Expand All @@ -1142,8 +1174,6 @@ def _build_hybrid_search_query(
"embedding": embedding,
"query": query,
"analyzer": self.keyword_analyzer,
"rrf_constant": self.rrf_constant,
"rrf_search_limit": self.rrf_search_limit,
}

return aql_query, bind_vars
Loading