Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 54 additions & 74 deletions libs/arangodb/langchain_arangodb/vectorstores/arangodb_vector.py
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -142,10 +142,7 @@ def __init__(
DistanceStrategy.DOT_PRODUCT, DistanceStrategy.DOT_PRODUCT,
DistanceStrategy.MAX_INNER_PRODUCT, DistanceStrategy.MAX_INNER_PRODUCT,
]: ]:
m = ( m = "distance_strategy must be one of: 'COSINE', 'EUCLIDEAN_DISTANCE', 'JACCARD', 'DOT_PRODUCT', 'MAX_INNER_PRODUCT'" # noqa: E501
"distance_strategy must be one of: 'COSINE', 'EUCLIDEAN_DISTANCE', "
"'JACCARD', 'DOT_PRODUCT', 'MAX_INNER_PRODUCT'"
)
raise ValueError(m) raise ValueError(m)


self.embedding = embedding self.embedding = embedding
Expand Down Expand Up @@ -1223,15 +1220,15 @@ def _process_search_query(self, cursor: Cursor) -> List[tuple[Document, float]]:


return results return results


def _build_vector_search_query( def _get_score_query_and_sort_order(self, use_approx: bool) -> Tuple[str, str]:
self, """Get the score query and sort order for the given distance strategy.
embedding: List[float],
k: int, :param use_approx: Whether to use approximate nearest neighbor search.
return_fields: set[str], :type use_approx: bool
use_approx: bool, :return: A tuple containing the score query and sort order.
filter_clause: str, :rtype: Tuple[str, str]
metadata_clause: str, """
) -> Tuple[str, dict[str, Any]]:
if self._distance_strategy == DistanceStrategy.COSINE: if self._distance_strategy == DistanceStrategy.COSINE:
score_func = "APPROX_NEAR_COSINE" if use_approx else "COSINE_SIMILARITY" score_func = "APPROX_NEAR_COSINE" if use_approx else "COSINE_SIMILARITY"
scoring_query = f"{score_func}(doc.{self.embedding_field}, @embedding)" scoring_query = f"{score_func}(doc.{self.embedding_field}, @embedding)"
Expand All @@ -1240,38 +1237,57 @@ def _build_vector_search_query(
score_func = "APPROX_NEAR_L2" if use_approx else "L2_DISTANCE" score_func = "APPROX_NEAR_L2" if use_approx else "L2_DISTANCE"
scoring_query = f"{score_func}(doc.{self.embedding_field}, @embedding)" scoring_query = f"{score_func}(doc.{self.embedding_field}, @embedding)"
sort_order = "ASC" sort_order = "ASC"
elif self._distance_strategy == DistanceStrategy.JACCARD:
use_approx = False
score_func = "JACCARD"
scoring_query = f"{score_func}(doc.{self.embedding_field}, @embedding)"
sort_order = "DESC"
elif self._distance_strategy in [ elif self._distance_strategy in [
DistanceStrategy.JACCARD,
DistanceStrategy.MAX_INNER_PRODUCT, DistanceStrategy.MAX_INNER_PRODUCT,
DistanceStrategy.DOT_PRODUCT, DistanceStrategy.DOT_PRODUCT,
]: ]:
if use_approx: scoring_query = """
raise ValueError( SUM(
f"Unsupported metric: {self._distance_strategy} is not supported " FOR i IN 0..LENGTH(doc.embedding)-1
"for approximate search" RETURN doc.embedding[i] * @embedding[i]
)
if self._distance_strategy == DistanceStrategy.JACCARD:
score_func = "JACCARD"
scoring_query = f"{score_func}(doc.{self.embedding_field}, @embedding)"
elif self._distance_strategy in [
DistanceStrategy.DOT_PRODUCT,
DistanceStrategy.MAX_INNER_PRODUCT,
]:
scoring_query = (
"SUM(FOR i IN 0..LENGTH(doc.embedding)-1 "
"RETURN doc.embedding[i] * @embedding[i])"
) )
"""
sort_order = "DESC" sort_order = "DESC"
else: else:
raise ValueError(f"Unsupported metric: {self._distance_strategy}") raise ValueError(f"Unsupported metric: {self._distance_strategy}")


if use_approx: return scoring_query, sort_order
if version.parse(self.db.version()) < version.parse("3.12.4"): # type: ignore
m = "Approximate Nearest Neighbor search requires ArangoDB >= 3.12.4." def _ensure_vector_index(self) -> None:
raise ValueError(m) """Ensure the vector index exists."""
if self._distance_strategy in [
DistanceStrategy.JACCARD,
DistanceStrategy.DOT_PRODUCT,
DistanceStrategy.MAX_INNER_PRODUCT,
]:
m = f"Unsupported metric: {self._distance_strategy} is not supported for approximate search" # noqa: E501
raise ValueError(m)


if not self.retrieve_vector_index(): if version.parse(self.db.version()) < version.parse("3.12.4"): # type: ignore
self.create_vector_index() m = "Approximate Nearest Neighbor search requires ArangoDB >= 3.12.4."
raise ValueError(m)

if not self.retrieve_vector_index():
self.create_vector_index()

def _build_vector_search_query(
self,
embedding: List[float],
k: int,
return_fields: set[str],
use_approx: bool,
filter_clause: str,
metadata_clause: str,
) -> Tuple[str, dict[str, Any]]:
scoring_query, sort_order = self._get_score_query_and_sort_order(use_approx)

if use_approx:
self._ensure_vector_index()


return_fields.update({"_key", self.text_field}) return_fields.update({"_key", self.text_field})
return_fields_list = list(return_fields) return_fields_list = list(return_fields)
Expand Down Expand Up @@ -1336,49 +1352,13 @@ def _build_hybrid_search_query(
) -> Tuple[str, dict[str, Any]]: ) -> Tuple[str, dict[str, Any]]:
"""Build the hybrid search query using RRF.""" """Build the hybrid search query using RRF."""


scoring_query, sort_order = self._get_score_query_and_sort_order(use_approx)

if not self.retrieve_keyword_index(): if not self.retrieve_keyword_index():
self.create_keyword_index() self.create_keyword_index()


if self._distance_strategy == DistanceStrategy.COSINE:
score_func = "APPROX_NEAR_COSINE" if use_approx else "COSINE_SIMILARITY"
scoring_query = f"{score_func}(doc.{self.embedding_field}, @embedding)"
sort_order = "DESC"
elif self._distance_strategy == DistanceStrategy.EUCLIDEAN_DISTANCE:
score_func = "APPROX_NEAR_L2" if use_approx else "L2_DISTANCE"
scoring_query = f"{score_func}(doc.{self.embedding_field}, @embedding)"
sort_order = "ASC"
elif self._distance_strategy in [
DistanceStrategy.JACCARD,
DistanceStrategy.MAX_INNER_PRODUCT,
DistanceStrategy.DOT_PRODUCT,
]:
if use_approx:
raise ValueError(
f"Unsupported metric: {self._distance_strategy} is not supported "
"for approximate search"
)
if self._distance_strategy == DistanceStrategy.JACCARD:
score_func = "JACCARD"
scoring_query = f"{score_func}(doc.{self.embedding_field}, @embedding)"
elif self._distance_strategy in [
DistanceStrategy.DOT_PRODUCT,
DistanceStrategy.MAX_INNER_PRODUCT,
]:
scoring_query = (
"SUM(FOR i IN 0..LENGTH(doc.embedding)-1 "
"RETURN doc.embedding[i] * @embedding[i])"
)
sort_order = "DESC"
else:
raise ValueError(f"Unsupported metric: {self._distance_strategy}")

if use_approx: if use_approx:
if version.parse(self.db.version()) < version.parse("3.12.4"): # type: ignore self._ensure_vector_index()
m = "Approximate Nearest Neighbor search requires ArangoDB >= 3.12.4."
raise ValueError(m)

if not self.retrieve_vector_index():
self.create_vector_index()


return_fields.update({"_key", self.text_field}) return_fields.update({"_key", self.text_field})
return_fields_list = list(return_fields) return_fields_list = list(return_fields)
Expand Down