Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 54 additions & 74 deletions libs/arangodb/langchain_arangodb/vectorstores/arangodb_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,10 +142,7 @@ def __init__(
DistanceStrategy.DOT_PRODUCT,
DistanceStrategy.MAX_INNER_PRODUCT,
]:
m = (
"distance_strategy must be one of: 'COSINE', 'EUCLIDEAN_DISTANCE', "
"'JACCARD', 'DOT_PRODUCT', 'MAX_INNER_PRODUCT'"
)
m = "distance_strategy must be one of: 'COSINE', 'EUCLIDEAN_DISTANCE', 'JACCARD', 'DOT_PRODUCT', 'MAX_INNER_PRODUCT'" # noqa: E501
raise ValueError(m)

self.embedding = embedding
Expand Down Expand Up @@ -1223,15 +1220,15 @@ def _process_search_query(self, cursor: Cursor) -> List[tuple[Document, float]]:

return results

def _build_vector_search_query(
self,
embedding: List[float],
k: int,
return_fields: set[str],
use_approx: bool,
filter_clause: str,
metadata_clause: str,
) -> Tuple[str, dict[str, Any]]:
def _get_score_query_and_sort_order(self, use_approx: bool) -> Tuple[str, str]:
"""Get the score query and sort order for the given distance strategy.

:param use_approx: Whether to use approximate nearest neighbor search.
:type use_approx: bool
:return: A tuple containing the score query and sort order.
:rtype: Tuple[str, str]
"""

if self._distance_strategy == DistanceStrategy.COSINE:
score_func = "APPROX_NEAR_COSINE" if use_approx else "COSINE_SIMILARITY"
scoring_query = f"{score_func}(doc.{self.embedding_field}, @embedding)"
Expand All @@ -1240,38 +1237,57 @@ def _build_vector_search_query(
score_func = "APPROX_NEAR_L2" if use_approx else "L2_DISTANCE"
scoring_query = f"{score_func}(doc.{self.embedding_field}, @embedding)"
sort_order = "ASC"
elif self._distance_strategy == DistanceStrategy.JACCARD:
use_approx = False
score_func = "JACCARD"
scoring_query = f"{score_func}(doc.{self.embedding_field}, @embedding)"
sort_order = "DESC"
elif self._distance_strategy in [
DistanceStrategy.JACCARD,
DistanceStrategy.MAX_INNER_PRODUCT,
DistanceStrategy.DOT_PRODUCT,
]:
if use_approx:
raise ValueError(
f"Unsupported metric: {self._distance_strategy} is not supported "
"for approximate search"
)
if self._distance_strategy == DistanceStrategy.JACCARD:
score_func = "JACCARD"
scoring_query = f"{score_func}(doc.{self.embedding_field}, @embedding)"
elif self._distance_strategy in [
DistanceStrategy.DOT_PRODUCT,
DistanceStrategy.MAX_INNER_PRODUCT,
]:
scoring_query = (
"SUM(FOR i IN 0..LENGTH(doc.embedding)-1 "
"RETURN doc.embedding[i] * @embedding[i])"
scoring_query = """
SUM(
FOR i IN 0..LENGTH(doc.embedding)-1
RETURN doc.embedding[i] * @embedding[i]
)
"""
sort_order = "DESC"
else:
raise ValueError(f"Unsupported metric: {self._distance_strategy}")

if use_approx:
if version.parse(self.db.version()) < version.parse("3.12.4"): # type: ignore
m = "Approximate Nearest Neighbor search requires ArangoDB >= 3.12.4."
raise ValueError(m)
return scoring_query, sort_order

def _ensure_vector_index(self) -> None:
"""Ensure the vector index exists."""
if self._distance_strategy in [
DistanceStrategy.JACCARD,
DistanceStrategy.DOT_PRODUCT,
DistanceStrategy.MAX_INNER_PRODUCT,
]:
m = f"Unsupported metric: {self._distance_strategy} is not supported for approximate search" # noqa: E501
raise ValueError(m)

if not self.retrieve_vector_index():
self.create_vector_index()
if version.parse(self.db.version()) < version.parse("3.12.4"): # type: ignore
m = "Approximate Nearest Neighbor search requires ArangoDB >= 3.12.4."
raise ValueError(m)

if not self.retrieve_vector_index():
self.create_vector_index()

def _build_vector_search_query(
self,
embedding: List[float],
k: int,
return_fields: set[str],
use_approx: bool,
filter_clause: str,
metadata_clause: str,
) -> Tuple[str, dict[str, Any]]:
scoring_query, sort_order = self._get_score_query_and_sort_order(use_approx)

if use_approx:
self._ensure_vector_index()

return_fields.update({"_key", self.text_field})
return_fields_list = list(return_fields)
Expand Down Expand Up @@ -1336,49 +1352,13 @@ def _build_hybrid_search_query(
) -> Tuple[str, dict[str, Any]]:
"""Build the hybrid search query using RRF."""

scoring_query, sort_order = self._get_score_query_and_sort_order(use_approx)

if not self.retrieve_keyword_index():
self.create_keyword_index()

if self._distance_strategy == DistanceStrategy.COSINE:
score_func = "APPROX_NEAR_COSINE" if use_approx else "COSINE_SIMILARITY"
scoring_query = f"{score_func}(doc.{self.embedding_field}, @embedding)"
sort_order = "DESC"
elif self._distance_strategy == DistanceStrategy.EUCLIDEAN_DISTANCE:
score_func = "APPROX_NEAR_L2" if use_approx else "L2_DISTANCE"
scoring_query = f"{score_func}(doc.{self.embedding_field}, @embedding)"
sort_order = "ASC"
elif self._distance_strategy in [
DistanceStrategy.JACCARD,
DistanceStrategy.MAX_INNER_PRODUCT,
DistanceStrategy.DOT_PRODUCT,
]:
if use_approx:
raise ValueError(
f"Unsupported metric: {self._distance_strategy} is not supported "
"for approximate search"
)
if self._distance_strategy == DistanceStrategy.JACCARD:
score_func = "JACCARD"
scoring_query = f"{score_func}(doc.{self.embedding_field}, @embedding)"
elif self._distance_strategy in [
DistanceStrategy.DOT_PRODUCT,
DistanceStrategy.MAX_INNER_PRODUCT,
]:
scoring_query = (
"SUM(FOR i IN 0..LENGTH(doc.embedding)-1 "
"RETURN doc.embedding[i] * @embedding[i])"
)
sort_order = "DESC"
else:
raise ValueError(f"Unsupported metric: {self._distance_strategy}")

if use_approx:
if version.parse(self.db.version()) < version.parse("3.12.4"): # type: ignore
m = "Approximate Nearest Neighbor search requires ArangoDB >= 3.12.4."
raise ValueError(m)

if not self.retrieve_vector_index():
self.create_vector_index()
self._ensure_vector_index()

return_fields.update({"_key", self.text_field})
return_fields_list = list(return_fields)
Expand Down