Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added libs/arangodb/.coverage
Binary file not shown.
201 changes: 148 additions & 53 deletions libs/arangodb/langchain_arangodb/vectorstores/arangodb_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,11 @@ def __init__(
if distance_strategy not in [
DistanceStrategy.COSINE,
DistanceStrategy.EUCLIDEAN_DISTANCE,
DistanceStrategy.JACCARD,
DistanceStrategy.DOT_PRODUCT,
DistanceStrategy.MAX_INNER_PRODUCT,
]:
m = "distance_strategy must be 'COSINE' or 'EUCLIDEAN_DISTANCE'"
m = "distance_strategy must be one of: 'COSINE', 'EUCLIDEAN_DISTANCE', 'JACCARD', 'DOT_PRODUCT', 'MAX_INNER_PRODUCT'" # noqa: E501
raise ValueError(m)

self.embedding = embedding
Expand Down Expand Up @@ -1217,6 +1220,61 @@ def _process_search_query(self, cursor: Cursor) -> List[tuple[Document, float]]:

return results

def _get_score_query_and_sort_order(self, use_approx: bool) -> Tuple[str, str]:
    """Get the score query and sort order for the given distance strategy.

    The returned score query is an AQL expression that compares the stored
    vector of the current ``doc`` (read from ``self.embedding_field``) with
    the ``@embedding`` bind variable.

    :param use_approx: Whether to use approximate nearest neighbor search.
        Only COSINE and EUCLIDEAN_DISTANCE have an approximate variant;
        for every other strategy the flag is ignored and the exact
        expression is returned.
    :type use_approx: bool
    :return: A tuple containing the score query and sort order
        (``"ASC"`` or ``"DESC"``).
    :rtype: Tuple[str, str]
    :raises ValueError: If the configured distance strategy is not handled.
    """
    strategy = self._distance_strategy
    # Always address the vector through the configured field name so that
    # collections using a non-default embedding field are scored correctly.
    vector_field = f"doc.{self.embedding_field}"

    if strategy == DistanceStrategy.COSINE:
        score_func = "APPROX_NEAR_COSINE" if use_approx else "COSINE_SIMILARITY"
        return f"{score_func}({vector_field}, @embedding)", "DESC"

    if strategy == DistanceStrategy.EUCLIDEAN_DISTANCE:
        score_func = "APPROX_NEAR_L2" if use_approx else "L2_DISTANCE"
        # L2 is a distance, so the best matches are the smallest scores.
        return f"{score_func}({vector_field}, @embedding)", "ASC"

    if strategy == DistanceStrategy.JACCARD:
        # No approximate variant is emitted for JACCARD; use_approx is ignored.
        return f"JACCARD({vector_field}, @embedding)", "DESC"

    if strategy in [
        DistanceStrategy.MAX_INNER_PRODUCT,
        DistanceStrategy.DOT_PRODUCT,
    ]:
        # Element-wise product summed in AQL; use_approx is ignored here too.
        scoring_query = f"""
            SUM(
                FOR i IN 0..LENGTH({vector_field})-1
                    RETURN {vector_field}[i] * @embedding[i]
            )
        """
        return scoring_query, "DESC"

    raise ValueError(f"Unsupported metric: {self._distance_strategy}")

def _ensure_vector_index(self) -> None:
    """Check that approximate search is usable and create the vector index if absent.

    The checks run in this order: distance strategy, server version,
    index presence.

    :raises ValueError: If the configured distance strategy is JACCARD,
        DOT_PRODUCT or MAX_INNER_PRODUCT, or if the database reports a
        version lower than 3.12.4.
    """
    exact_only_strategies = (
        DistanceStrategy.JACCARD,
        DistanceStrategy.DOT_PRODUCT,
        DistanceStrategy.MAX_INNER_PRODUCT,
    )
    if self._distance_strategy in exact_only_strategies:
        raise ValueError(
            f"Unsupported metric: {self._distance_strategy} is not supported for approximate search"  # noqa: E501
        )

    server_version = version.parse(self.db.version())  # type: ignore
    if server_version < version.parse("3.12.4"):
        raise ValueError(
            "Approximate Nearest Neighbor search requires ArangoDB >= 3.12.4."
        )

    # Nothing more to do when an index is already in place.
    if self.retrieve_vector_index():
        return
    self.create_vector_index()

def _build_vector_search_query(
self,
embedding: List[float],
Expand All @@ -1226,37 +1284,51 @@ def _build_vector_search_query(
filter_clause: str,
metadata_clause: str,
) -> Tuple[str, dict[str, Any]]:
if self._distance_strategy == DistanceStrategy.COSINE:
score_func = "APPROX_NEAR_COSINE" if use_approx else "COSINE_SIMILARITY"
sort_order = "DESC"
elif self._distance_strategy == DistanceStrategy.EUCLIDEAN_DISTANCE:
score_func = "APPROX_NEAR_L2" if use_approx else "L2_DISTANCE"
sort_order = "ASC"
else:
raise ValueError(f"Unsupported metric: {self._distance_strategy}")
scoring_query, sort_order = self._get_score_query_and_sort_order(use_approx)

if use_approx:
if version.parse(self.db.version()) < version.parse("3.12.4"): # type: ignore
m = "Approximate Nearest Neighbor search requires ArangoDB >= 3.12.4."
raise ValueError(m)

if not self.retrieve_vector_index():
self.create_vector_index()
self._ensure_vector_index()

return_fields.update({"_key", self.text_field})
return_fields_list = list(return_fields)

aql_query = f"""
FOR doc IN @@collection
{filter_clause if not use_approx else ""}
LET score = {score_func}(doc.{self.embedding_field}, @embedding)
SORT score {sort_order}
LIMIT {k}
{filter_clause if use_approx else ""}
LET data = KEEP(doc, {return_fields_list})
LET metadata = {f"({metadata_clause})" if metadata_clause else "{}"}
RETURN {{data, score, metadata}}
"""
if self._distance_strategy in [
DistanceStrategy.JACCARD,
DistanceStrategy.COSINE,
DistanceStrategy.EUCLIDEAN_DISTANCE,
DistanceStrategy.DOT_PRODUCT,
]:
aql_query = f"""
FOR doc IN @@collection
{filter_clause if not use_approx else ""}
LET score = {scoring_query}
SORT score {sort_order}
LIMIT {k}
{filter_clause if use_approx else ""}
LET data = KEEP(doc, {return_fields_list})
LET metadata = {f"({metadata_clause})" if metadata_clause else "{}"}
RETURN {{data, score, metadata}}
"""
elif self._distance_strategy == DistanceStrategy.MAX_INNER_PRODUCT:
aql_query = f"""
LET scored = (
FOR doc IN @@collection
{filter_clause}
LET score = {scoring_query}
SORT score {sort_order}
LIMIT {k}
RETURN {{doc, score}}
)
LET maxScore = MAX(scored[*].score)

FOR item IN scored
FILTER item.score == maxScore
LET data = KEEP(item.doc, {return_fields_list})
LET metadata = {f"({metadata_clause})" if metadata_clause else "{}"}
RETURN {{data, score: item.score, metadata}}
"""
else:
raise ValueError(f"Unsupported metric: {self._distance_strategy}")

bind_vars = {
"@collection": self.collection_name,
Expand All @@ -1280,25 +1352,13 @@ def _build_hybrid_search_query(
) -> Tuple[str, dict[str, Any]]:
"""Build the hybrid search query using RRF."""

scoring_query, sort_order = self._get_score_query_and_sort_order(use_approx)

if not self.retrieve_keyword_index():
self.create_keyword_index()

if self._distance_strategy == DistanceStrategy.COSINE:
score_func = "APPROX_NEAR_COSINE" if use_approx else "COSINE_SIMILARITY"
sort_order = "DESC"
elif self._distance_strategy == DistanceStrategy.EUCLIDEAN_DISTANCE:
score_func = "APPROX_NEAR_L2" if use_approx else "L2_DISTANCE"
sort_order = "ASC"
else:
raise ValueError(f"Unsupported metric: {self._distance_strategy}")

if use_approx:
if version.parse(self.db.version()) < version.parse("3.12.4"): # type: ignore
m = "Approximate Nearest Neighbor search requires ArangoDB >= 3.12.4."
raise ValueError(m)

if not self.retrieve_vector_index():
self.create_vector_index()
self._ensure_vector_index()

return_fields.update({"_key", self.text_field})
return_fields_list = list(return_fields)
Expand All @@ -1311,19 +1371,54 @@ def _build_hybrid_search_query(
)
"""

if self._distance_strategy in [
DistanceStrategy.JACCARD,
DistanceStrategy.COSINE,
DistanceStrategy.EUCLIDEAN_DISTANCE,
DistanceStrategy.DOT_PRODUCT,
]:
vector_search_query = f"""
LET vector_results = (
FOR doc IN @@collection
{filter_clause if not use_approx else ""}
LET score = {scoring_query}
SORT score {sort_order}
LIMIT {k}
{filter_clause if use_approx else ""}
WINDOW {{ preceding: "unbounded", following: 0 }}
AGGREGATE rank = COUNT(1)
LET rrf_score = {vector_weight} / ({self.rrf_constant} + rank)
RETURN {{ key: doc._key, score: rrf_score }}
)
"""
elif self._distance_strategy == DistanceStrategy.MAX_INNER_PRODUCT:
vector_search_query = f"""
LET scored = (
FOR doc IN @@collection
{filter_clause}
LET score = SUM(
FOR i IN 0..LENGTH(doc.embedding)-1
RETURN doc.embedding[i] * @embedding[i]
)
SORT score {sort_order}
LIMIT {k}
RETURN {{doc, score}}
)
LET maxScore = MAX(scored[*].score)

LET vector_results = (
FOR item IN scored
FILTER item.score == maxScore
LET rank = 1
LET rrf_score = {vector_weight} / ({self.rrf_constant} + rank)
RETURN {{ key: item.doc._key, score: rrf_score }}
)
"""
else:
raise ValueError(f"Unsupported metric: {self._distance_strategy}")

aql_query = f"""
LET vector_results = (
FOR doc IN @@collection
{filter_clause if not use_approx else ""}
LET score = {score_func}(doc.{self.embedding_field}, @embedding)
SORT score {sort_order}
LIMIT {k}
{filter_clause if use_approx else ""}
WINDOW {{ preceding: "unbounded", following: 0 }}
AGGREGATE rank = COUNT(1)
LET rrf_score = {vector_weight} / ({self.rrf_constant} + rank)
RETURN {{ key: doc._key, score: rrf_score }}
)
{vector_search_query}

LET keyword_results = (
FOR doc IN @@view
Expand Down
8 changes: 7 additions & 1 deletion libs/arangodb/langchain_arangodb/vectorstores/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,13 @@

class DistanceStrategy(str, Enum):
"""Enumerator of the Distance strategies for calculating distances
between vectors."""
between vectors.

Note that **use_approx** is not supported for the following distance strategies:
- JACCARD
- MAX_INNER_PRODUCT
- DOT_PRODUCT
"""

EUCLIDEAN_DISTANCE = "EUCLIDEAN_DISTANCE"
MAX_INNER_PRODUCT = "MAX_INNER_PRODUCT"
Expand Down
Loading