diff --git a/libs/arangodb/langchain_arangodb/vectorstores/arangodb_vector.py b/libs/arangodb/langchain_arangodb/vectorstores/arangodb_vector.py index 56800d7..db42d98 100644 --- a/libs/arangodb/langchain_arangodb/vectorstores/arangodb_vector.py +++ b/libs/arangodb/langchain_arangodb/vectorstores/arangodb_vector.py @@ -142,10 +142,7 @@ def __init__( DistanceStrategy.DOT_PRODUCT, DistanceStrategy.MAX_INNER_PRODUCT, ]: - m = ( - "distance_strategy must be one of: 'COSINE', 'EUCLIDEAN_DISTANCE', " - "'JACCARD', 'DOT_PRODUCT', 'MAX_INNER_PRODUCT'" - ) + m = "distance_strategy must be one of: 'COSINE', 'EUCLIDEAN_DISTANCE', 'JACCARD', 'DOT_PRODUCT', 'MAX_INNER_PRODUCT'" # noqa: E501 raise ValueError(m) self.embedding = embedding @@ -1223,15 +1220,15 @@ def _process_search_query(self, cursor: Cursor) -> List[tuple[Document, float]]: return results - def _build_vector_search_query( - self, - embedding: List[float], - k: int, - return_fields: set[str], - use_approx: bool, - filter_clause: str, - metadata_clause: str, - ) -> Tuple[str, dict[str, Any]]: + def _get_score_query_and_sort_order(self, use_approx: bool) -> Tuple[str, str]: + """Get the score query and sort order for the given distance strategy. + + :param use_approx: Whether to use approximate nearest neighbor search. + :type use_approx: bool + :return: A tuple containing the score query and sort order. + :rtype: Tuple[str, str] + """ + if self._distance_strategy == DistanceStrategy.COSINE: score_func = "APPROX_NEAR_COSINE" if use_approx else "COSINE_SIMILARITY" scoring_query = f"{score_func}(doc.{self.embedding_field}, @embedding)" @@ -1240,38 +1237,57 @@ def _build_vector_search_query( score_func = "APPROX_NEAR_L2" if use_approx else "L2_DISTANCE" scoring_query = f"{score_func}(doc.{self.embedding_field}, @embedding)" sort_order = "ASC" + elif self._distance_strategy == DistanceStrategy.JACCARD: + use_approx = False + score_func = "JACCARD" + scoring_query = f"{score_func}(doc.{self.embedding_field}, @embedding)" + sort_order = "DESC" elif self._distance_strategy in [ - DistanceStrategy.JACCARD, DistanceStrategy.MAX_INNER_PRODUCT, DistanceStrategy.DOT_PRODUCT, ]: - if use_approx: - raise ValueError( - f"Unsupported metric: {self._distance_strategy} is not supported " - "for approximate search" - ) - if self._distance_strategy == DistanceStrategy.JACCARD: - score_func = "JACCARD" - scoring_query = f"{score_func}(doc.{self.embedding_field}, @embedding)" - elif self._distance_strategy in [ - DistanceStrategy.DOT_PRODUCT, - DistanceStrategy.MAX_INNER_PRODUCT, - ]: - scoring_query = ( - "SUM(FOR i IN 0..LENGTH(doc.embedding)-1 " - "RETURN doc.embedding[i] * @embedding[i])" + scoring_query = """ + SUM( + FOR i IN 0..LENGTH(doc.embedding)-1 + RETURN doc.embedding[i] * @embedding[i] ) + """ sort_order = "DESC" else: raise ValueError(f"Unsupported metric: {self._distance_strategy}") - if use_approx: - if version.parse(self.db.version()) < version.parse("3.12.4"): # type: ignore - m = "Approximate Nearest Neighbor search requires ArangoDB >= 3.12.4." - raise ValueError(m) + return scoring_query, sort_order + + def _ensure_vector_index(self) -> None: + """Ensure the vector index exists.""" + if self._distance_strategy in [ + DistanceStrategy.JACCARD, + DistanceStrategy.DOT_PRODUCT, + DistanceStrategy.MAX_INNER_PRODUCT, + ]: + m = f"Unsupported metric: {self._distance_strategy} is not supported for approximate search" # noqa: E501 + raise ValueError(m) - if not self.retrieve_vector_index(): - self.create_vector_index() + if version.parse(self.db.version()) < version.parse("3.12.4"): # type: ignore + m = "Approximate Nearest Neighbor search requires ArangoDB >= 3.12.4." + raise ValueError(m) + + if not self.retrieve_vector_index(): + self.create_vector_index() + + def _build_vector_search_query( + self, + embedding: List[float], + k: int, + return_fields: set[str], + use_approx: bool, + filter_clause: str, + metadata_clause: str, + ) -> Tuple[str, dict[str, Any]]: + scoring_query, sort_order = self._get_score_query_and_sort_order(use_approx) + + if use_approx: + self._ensure_vector_index() return_fields.update({"_key", self.text_field}) return_fields_list = list(return_fields) @@ -1336,49 +1352,13 @@ def _build_hybrid_search_query( ) -> Tuple[str, dict[str, Any]]: """Build the hybrid search query using RRF.""" + scoring_query, sort_order = self._get_score_query_and_sort_order(use_approx) + if not self.retrieve_keyword_index(): self.create_keyword_index() - if self._distance_strategy == DistanceStrategy.COSINE: - score_func = "APPROX_NEAR_COSINE" if use_approx else "COSINE_SIMILARITY" - scoring_query = f"{score_func}(doc.{self.embedding_field}, @embedding)" - sort_order = "DESC" - elif self._distance_strategy == DistanceStrategy.EUCLIDEAN_DISTANCE: - score_func = "APPROX_NEAR_L2" if use_approx else "L2_DISTANCE" - scoring_query = f"{score_func}(doc.{self.embedding_field}, @embedding)" - sort_order = "ASC" - elif self._distance_strategy in [ - DistanceStrategy.JACCARD, - DistanceStrategy.MAX_INNER_PRODUCT, - DistanceStrategy.DOT_PRODUCT, - ]: - if use_approx: - raise ValueError( - f"Unsupported metric: {self._distance_strategy} is not supported " - "for approximate search" - ) - if self._distance_strategy == DistanceStrategy.JACCARD: - score_func = "JACCARD" - scoring_query = f"{score_func}(doc.{self.embedding_field}, @embedding)" - elif self._distance_strategy in [ - DistanceStrategy.DOT_PRODUCT, - DistanceStrategy.MAX_INNER_PRODUCT, - ]: - scoring_query = ( - "SUM(FOR i IN 0..LENGTH(doc.embedding)-1 " - "RETURN doc.embedding[i] * @embedding[i])" - ) - sort_order = "DESC" - else: - raise ValueError(f"Unsupported metric: {self._distance_strategy}") - if use_approx: - if version.parse(self.db.version()) < version.parse("3.12.4"): # type: ignore - m = "Approximate Nearest Neighbor search requires ArangoDB >= 3.12.4." - raise ValueError(m) - - if not self.retrieve_vector_index(): - self.create_vector_index() + self._ensure_vector_index() return_fields.update({"_key", self.text_field}) return_fields_list = list(return_fields)