diff --git a/libs/arangodb/.coverage b/libs/arangodb/.coverage new file mode 100644 index 0000000..ec75e4c Binary files /dev/null and b/libs/arangodb/.coverage differ diff --git a/libs/arangodb/langchain_arangodb/vectorstores/arangodb_vector.py b/libs/arangodb/langchain_arangodb/vectorstores/arangodb_vector.py index d9c9531..db42d98 100644 --- a/libs/arangodb/langchain_arangodb/vectorstores/arangodb_vector.py +++ b/libs/arangodb/langchain_arangodb/vectorstores/arangodb_vector.py @@ -138,8 +138,11 @@ def __init__( if distance_strategy not in [ DistanceStrategy.COSINE, DistanceStrategy.EUCLIDEAN_DISTANCE, + DistanceStrategy.JACCARD, + DistanceStrategy.DOT_PRODUCT, + DistanceStrategy.MAX_INNER_PRODUCT, ]: - m = "distance_strategy must be 'COSINE' or 'EUCLIDEAN_DISTANCE'" + m = "distance_strategy must be one of: 'COSINE', 'EUCLIDEAN_DISTANCE', 'JACCARD', 'DOT_PRODUCT', 'MAX_INNER_PRODUCT'" # noqa: E501 raise ValueError(m) self.embedding = embedding @@ -1217,6 +1220,61 @@ def _process_search_query(self, cursor: Cursor) -> List[tuple[Document, float]]: return results + def _get_score_query_and_sort_order(self, use_approx: bool) -> Tuple[str, str]: + """Get the score query and sort order for the given distance strategy. + + :param use_approx: Whether to use approximate nearest neighbor search. + :type use_approx: bool + :return: A tuple containing the score query and sort order. + :rtype: Tuple[str, str] + """ + + if self._distance_strategy == DistanceStrategy.COSINE: + score_func = "APPROX_NEAR_COSINE" if use_approx else "COSINE_SIMILARITY" + scoring_query = f"{score_func}(doc.{self.embedding_field}, @embedding)" + sort_order = "DESC" + elif self._distance_strategy == DistanceStrategy.EUCLIDEAN_DISTANCE: + score_func = "APPROX_NEAR_L2" if use_approx else "L2_DISTANCE" + scoring_query = f"{score_func}(doc.{self.embedding_field}, @embedding)" + sort_order = "ASC" + elif self._distance_strategy == DistanceStrategy.JACCARD: + use_approx = False + score_func = "JACCARD" + scoring_query = f"{score_func}(doc.{self.embedding_field}, @embedding)" + sort_order = "DESC" + elif self._distance_strategy in [ + DistanceStrategy.MAX_INNER_PRODUCT, + DistanceStrategy.DOT_PRODUCT, + ]: + scoring_query = """ + SUM( + FOR i IN 0..LENGTH(doc.embedding)-1 + RETURN doc.embedding[i] * @embedding[i] + ) + """ + sort_order = "DESC" + else: + raise ValueError(f"Unsupported metric: {self._distance_strategy}") + + return scoring_query, sort_order + + def _ensure_vector_index(self) -> None: + """Ensure the vector index exists.""" + if self._distance_strategy in [ + DistanceStrategy.JACCARD, + DistanceStrategy.DOT_PRODUCT, + DistanceStrategy.MAX_INNER_PRODUCT, + ]: + m = f"Unsupported metric: {self._distance_strategy} is not supported for approximate search" # noqa: E501 + raise ValueError(m) + + if version.parse(self.db.version()) < version.parse("3.12.4"): # type: ignore + m = "Approximate Nearest Neighbor search requires ArangoDB >= 3.12.4." + raise ValueError(m) + + if not self.retrieve_vector_index(): + self.create_vector_index() + def _build_vector_search_query( self, embedding: List[float], @@ -1226,37 +1284,51 @@ def _build_vector_search_query( filter_clause: str, metadata_clause: str, ) -> Tuple[str, dict[str, Any]]: - if self._distance_strategy == DistanceStrategy.COSINE: - score_func = "APPROX_NEAR_COSINE" if use_approx else "COSINE_SIMILARITY" - sort_order = "DESC" - elif self._distance_strategy == DistanceStrategy.EUCLIDEAN_DISTANCE: - score_func = "APPROX_NEAR_L2" if use_approx else "L2_DISTANCE" - sort_order = "ASC" - else: - raise ValueError(f"Unsupported metric: {self._distance_strategy}") + scoring_query, sort_order = self._get_score_query_and_sort_order(use_approx) if use_approx: - if version.parse(self.db.version()) < version.parse("3.12.4"): # type: ignore - m = "Approximate Nearest Neighbor search requires ArangoDB >= 3.12.4." - raise ValueError(m) - - if not self.retrieve_vector_index(): - self.create_vector_index() + self._ensure_vector_index() return_fields.update({"_key", self.text_field}) return_fields_list = list(return_fields) - aql_query = f""" - FOR doc IN @@collection - {filter_clause if not use_approx else ""} - LET score = {score_func}(doc.{self.embedding_field}, @embedding) - SORT score {sort_order} - LIMIT {k} - {filter_clause if use_approx else ""} - LET data = KEEP(doc, {return_fields_list}) - LET metadata = {f"({metadata_clause})" if metadata_clause else "{}"} - RETURN {{data, score, metadata}} - """ + if self._distance_strategy in [ + DistanceStrategy.JACCARD, + DistanceStrategy.COSINE, + DistanceStrategy.EUCLIDEAN_DISTANCE, + DistanceStrategy.DOT_PRODUCT, + ]: + aql_query = f""" + FOR doc IN @@collection + {filter_clause if not use_approx else ""} + LET score = {scoring_query} + SORT score {sort_order} + LIMIT {k} + {filter_clause if use_approx else ""} + LET data = KEEP(doc, {return_fields_list}) + LET metadata = {f"({metadata_clause})" if metadata_clause else "{}"} + RETURN {{data, score, metadata}} + """ + elif self._distance_strategy == DistanceStrategy.MAX_INNER_PRODUCT: + aql_query = f""" + LET scored = ( + FOR doc IN @@collection + {filter_clause} + LET score = {scoring_query} + SORT score {sort_order} + LIMIT {k} + RETURN {{doc, score}} + ) + LET maxScore = MAX(scored[*].score) + + FOR item IN scored + FILTER item.score == maxScore + LET data = KEEP(item.doc, {return_fields_list}) + LET metadata = {f"({metadata_clause})" if metadata_clause else "{}"} + RETURN {{data, score: item.score, metadata}} + """ + else: + raise ValueError(f"Unsupported metric: {self._distance_strategy}") bind_vars = { "@collection": self.collection_name, @@ -1280,25 +1352,13 @@ def _build_hybrid_search_query( ) -> Tuple[str, dict[str, Any]]: """Build the hybrid search query using RRF.""" + scoring_query, sort_order = self._get_score_query_and_sort_order(use_approx) + if not self.retrieve_keyword_index(): self.create_keyword_index() - if self._distance_strategy == DistanceStrategy.COSINE: - score_func = "APPROX_NEAR_COSINE" if use_approx else "COSINE_SIMILARITY" - sort_order = "DESC" - elif self._distance_strategy == DistanceStrategy.EUCLIDEAN_DISTANCE: - score_func = "APPROX_NEAR_L2" if use_approx else "L2_DISTANCE" - sort_order = "ASC" - else: - raise ValueError(f"Unsupported metric: {self._distance_strategy}") - if use_approx: - if version.parse(self.db.version()) < version.parse("3.12.4"): # type: ignore - m = "Approximate Nearest Neighbor search requires ArangoDB >= 3.12.4." - raise ValueError(m) - - if not self.retrieve_vector_index(): - self.create_vector_index() + self._ensure_vector_index() return_fields.update({"_key", self.text_field}) return_fields_list = list(return_fields) @@ -1311,19 +1371,54 @@ def _build_hybrid_search_query( ) """ + if self._distance_strategy in [ + DistanceStrategy.JACCARD, + DistanceStrategy.COSINE, + DistanceStrategy.EUCLIDEAN_DISTANCE, + DistanceStrategy.DOT_PRODUCT, + ]: + vector_search_query = f""" + LET vector_results = ( + FOR doc IN @@collection + {filter_clause if not use_approx else ""} + LET score = {scoring_query} + SORT score {sort_order} + LIMIT {k} + {filter_clause if use_approx else ""} + WINDOW {{ preceding: "unbounded", following: 0 }} + AGGREGATE rank = COUNT(1) + LET rrf_score = {vector_weight} / ({self.rrf_constant} + rank) + RETURN {{ key: doc._key, score: rrf_score }} + ) + """ + elif self._distance_strategy == DistanceStrategy.MAX_INNER_PRODUCT: + vector_search_query = f""" + LET scored = ( + FOR doc IN @@collection + {filter_clause} + LET score = SUM( + FOR i IN 0..LENGTH(doc.embedding)-1 + RETURN doc.embedding[i] * @embedding[i] + ) + SORT score {sort_order} + LIMIT {k} + RETURN {{doc, score}} + ) + LET maxScore = MAX(scored[*].score) + + LET vector_results = ( + FOR item IN scored + FILTER item.score == maxScore + LET rank = 1 + LET rrf_score = {vector_weight} / ({self.rrf_constant} + rank) + RETURN {{ key: item.doc._key, score: rrf_score }} + ) + """ + else: + raise ValueError(f"Unsupported metric: {self._distance_strategy}") + aql_query = f""" - LET vector_results = ( - FOR doc IN @@collection - {filter_clause if not use_approx else ""} - LET score = {score_func}(doc.{self.embedding_field}, @embedding) - SORT score {sort_order} - LIMIT {k} - {filter_clause if use_approx else ""} - WINDOW {{ preceding: "unbounded", following: 0 }} - AGGREGATE rank = COUNT(1) - LET rrf_score = {vector_weight} / ({self.rrf_constant} + rank) - RETURN {{ key: doc._key, score: rrf_score }} - ) + {vector_search_query} LET keyword_results = ( FOR doc IN @@view diff --git a/libs/arangodb/langchain_arangodb/vectorstores/utils.py b/libs/arangodb/langchain_arangodb/vectorstores/utils.py index 5ddda4c..12cd89c 100644 --- a/libs/arangodb/langchain_arangodb/vectorstores/utils.py +++ b/libs/arangodb/langchain_arangodb/vectorstores/utils.py @@ -3,7 +3,13 @@ class DistanceStrategy(str, Enum): """Enumerator of the Distance strategies for calculating distances - between vectors.""" + between vectors. + + Note that **use_approx** is not supported for the following distance strategies: + - JACCARD + - MAX_INNER_PRODUCT + - DOT_PRODUCT + """ EUCLIDEAN_DISTANCE = "EUCLIDEAN_DISTANCE" MAX_INNER_PRODUCT = "MAX_INNER_PRODUCT" diff --git a/libs/arangodb/tests/integration_tests/vectorstores/test_arangodb_vector.py b/libs/arangodb/tests/integration_tests/vectorstores/test_arangodb_vector.py index e8d71d9..e648172 100644 --- a/libs/arangodb/tests/integration_tests/vectorstores/test_arangodb_vector.py +++ b/libs/arangodb/tests/integration_tests/vectorstores/test_arangodb_vector.py @@ -1424,3 +1424,216 @@ def test_arangovector_hybrid_search_error_cases( # Should still return results (vector-only search) assert len(results_vector_only) >= 0 # May return 0 or more results + + +@pytest.mark.usefixtures("clear_arangodb_database") +def test_arangovector_jaccard_distance( + arangodb_credentials: ArangoCredentials, + fake_embedding_function: FakeEmbeddings, +) -> None: + """Test ArangoVector with Jaccard distance.""" + client = ArangoClient(hosts=arangodb_credentials["url"]) + db = client.db( + username=arangodb_credentials["username"], + password=arangodb_credentials["password"], + ) + texts_to_embed = ["foo", "bar", "baz"] + + vector_store = ArangoVector.from_texts( + texts=texts_to_embed, + embedding=fake_embedding_function, + database=db, + collection_name="test_collection", + distance_strategy=DistanceStrategy.JACCARD, + overwrite_index=False, + ) + + query = "foo" + results = vector_store.similarity_search(query, k=1, use_approx=False) + assert len(results) == 1 + assert results[0].page_content == "foo" + + # Test with scores + results_with_scores = vector_store.similarity_search_with_score( + query, k=1, use_approx=False + ) + assert len(results_with_scores) == 1 + assert 0.0 <= results_with_scores[0][1] <= 1.0 + + +@pytest.mark.usefixtures("clear_arangodb_database") +def test_arangovector_dot_product_distance( + arangodb_credentials: ArangoCredentials, + fake_embedding_function: FakeEmbeddings, +) -> None: + """Test ArangoVector with Dot Product distance.""" + client = ArangoClient(hosts=arangodb_credentials["url"]) + db = client.db( + username=arangodb_credentials["username"], + password=arangodb_credentials["password"], + ) + texts_to_embed = ["foo", "bar", "baz"] + + vector_store = ArangoVector.from_texts( + texts=texts_to_embed, + embedding=fake_embedding_function, + database=db, + collection_name="test_collection", + distance_strategy=DistanceStrategy.DOT_PRODUCT, + overwrite_index=False, + ) + + query = "foo" + results = vector_store.similarity_search(query, k=1, use_approx=False) + assert len(results) == 1 + assert results[0].page_content == "foo" + + # Test with scores + results_with_scores = vector_store.similarity_search_with_score( + query, k=1, use_approx=False + ) + assert len(results_with_scores) == 1 + assert isinstance(results_with_scores[0][1], (int, float)) + + +@pytest.mark.usefixtures("clear_arangodb_database") +def test_arangovector_max_inner_product_distance( + arangodb_credentials: ArangoCredentials, + fake_embedding_function: FakeEmbeddings, +) -> None: + """Test ArangoVector with Max Inner Product distance.""" + client = ArangoClient(hosts=arangodb_credentials["url"]) + db = client.db( + username=arangodb_credentials["username"], + password=arangodb_credentials["password"], + ) + texts_to_embed = ["foo", "bar", "baz"] + + vector_store = ArangoVector.from_texts( + texts=texts_to_embed, + embedding=fake_embedding_function, + database=db, + collection_name="test_collection", + distance_strategy=DistanceStrategy.MAX_INNER_PRODUCT, + overwrite_index=False, + ) + + query = "foo" + results = vector_store.similarity_search(query, k=1, use_approx=False) + assert len(results) == 1 + assert results[0].page_content == "foo" + + # Test with scores + results_with_scores = vector_store.similarity_search_with_score( + query, k=1, use_approx=False + ) + assert len(results_with_scores) == 1 + assert isinstance(results_with_scores[0][1], (int, float)) + + +@pytest.mark.usefixtures("clear_arangodb_database") +def test_arangovector_jaccard_hybrid_search( + arangodb_credentials: ArangoCredentials, + fake_embedding_function: FakeEmbeddings, +) -> None: + """Test JACCARD hybrid search.""" + client = ArangoClient(hosts=arangodb_credentials["url"]) + db = client.db( + username=arangodb_credentials["username"], + password=arangodb_credentials["password"], + ) + + # Clean up any leftover views from previous runs + try: + db.delete_view("keyword_index_jaccard_hybrid") + except Exception: + pass + + vector_store = ArangoVector.from_texts( + texts=["foo document", "bar document"], + embedding=fake_embedding_function, + database=db, + collection_name="test_collection", + keyword_index_name="keyword_index_jaccard_hybrid", + distance_strategy=DistanceStrategy.JACCARD, + search_type=SearchType.HYBRID, + insert_text=True, + ) + + results = vector_store.similarity_search( + "foo", k=2, search_type=SearchType.HYBRID, use_approx=False + ) + assert len(results) >= 1 + assert results[0].page_content == "foo document" + + +@pytest.mark.usefixtures("clear_arangodb_database") +def test_arangovector_dot_product_hybrid_search( + arangodb_credentials: ArangoCredentials, + fake_embedding_function: FakeEmbeddings, +) -> None: + """Test DOT_PRODUCT hybrid search.""" + client = ArangoClient(hosts=arangodb_credentials["url"]) + db = client.db( + username=arangodb_credentials["username"], + password=arangodb_credentials["password"], + ) + + # Clean up any leftover views from previous runs + try: + db.delete_view("keyword_index_dot_hybrid") + except Exception: + pass + + vector_store = ArangoVector.from_texts( + texts=["foo document", "bar document"], + embedding=fake_embedding_function, + database=db, + collection_name="test_collection", + keyword_index_name="keyword_index_dot_hybrid", + distance_strategy=DistanceStrategy.DOT_PRODUCT, + search_type=SearchType.HYBRID, + insert_text=True, + ) + + results = vector_store.similarity_search( + "foo", k=2, search_type=SearchType.HYBRID, use_approx=False + ) + assert len(results) >= 1 + assert results[0].page_content == "foo document" + + +@pytest.mark.usefixtures("clear_arangodb_database") +def test_arangovector_max_inner_product_hybrid_search( + arangodb_credentials: ArangoCredentials, + fake_embedding_function: FakeEmbeddings, +) -> None: + """Test MAX_INNER_PRODUCT hybrid search.""" + client = ArangoClient(hosts=arangodb_credentials["url"]) + db = client.db( + username=arangodb_credentials["username"], + password=arangodb_credentials["password"], + ) + + # Clean up any leftover views from previous runs + try: + db.delete_view("keyword_index_max_hybrid") + except Exception: + pass + + vector_store = ArangoVector.from_texts( + texts=["foo document", "bar document"], + embedding=fake_embedding_function, + database=db, + collection_name="test_collection", + keyword_index_name="keyword_index_max_hybrid", + distance_strategy=DistanceStrategy.MAX_INNER_PRODUCT, + search_type=SearchType.HYBRID, + insert_text=True, + ) + + results = vector_store.similarity_search( + "foo", k=2, search_type=SearchType.HYBRID, use_approx=False + ) + assert len(results) >= 1 + assert results[0].page_content == "foo document" diff --git a/libs/arangodb/tests/unit_tests/vectorstores/test_arangodb.py b/libs/arangodb/tests/unit_tests/vectorstores/test_arangodb.py index d0bd627..9c40789 100644 --- a/libs/arangodb/tests/unit_tests/vectorstores/test_arangodb.py +++ b/libs/arangodb/tests/unit_tests/vectorstores/test_arangodb.py @@ -141,9 +141,11 @@ def test_init_with_invalid_distance_strategy() -> None: distance_strategy="INVALID_STRATEGY", # type: ignore ) - assert "distance_strategy must be 'COSINE' or 'EUCLIDEAN_DISTANCE'" in str( - exc_info.value + expected_message = ( + "distance_strategy must be one of: 'COSINE', 'EUCLIDEAN_DISTANCE', " + "'JACCARD', 'DOT_PRODUCT', 'MAX_INNER_PRODUCT'" ) + assert expected_message in str(exc_info.value) def test_collection_creation_if_not_exists(arango_vector_factory: Any) -> None: @@ -1112,6 +1114,116 @@ def test_build_hybrid_search_query_euclidean_distance( assert "SORT score ASC" in query # Euclidean uses ascending sort +def test_build_hybrid_search_query_dot_product( + arango_vector_factory: Any, +) -> None: + """Test _build_hybrid_search_query with DOT_PRODUCT distance strategy.""" + from langchain_arangodb.vectorstores.utils import DistanceStrategy + + vector_store = arango_vector_factory(distance_strategy=DistanceStrategy.DOT_PRODUCT) + + # Mock dependencies + with patch.object( + vector_store, "retrieve_keyword_index", return_value={"name": "test_view"} + ): + with patch.object( + vector_store, "retrieve_vector_index", return_value={"name": "test_index"} + ): + query, bind_vars = vector_store._build_hybrid_search_query( + query="test", + k=2, + embedding=[0.1] * 64, + return_fields=set(), + use_approx=False, + filter_clause="", + vector_weight=0.7, + keyword_weight=0.3, + keyword_search_clause="", + metadata_clause="", + ) + + # Should use manual dot product calculation + assert "SUM(" in query + assert "doc.embedding[i] * @embedding[i]" in query + assert "SORT score DESC" in query + assert 'WINDOW { preceding: "unbounded", following: 0 }' in query + assert "AGGREGATE rank = COUNT(1)" in query + + +def test_build_hybrid_search_query_max_inner_product( + arango_vector_factory: Any, +) -> None: + """Test _build_hybrid_search_query with MAX_INNER_PRODUCT distance strategy.""" + from langchain_arangodb.vectorstores.utils import DistanceStrategy + + vector_store = arango_vector_factory( + distance_strategy=DistanceStrategy.MAX_INNER_PRODUCT + ) + + # Mock dependencies + with patch.object( + vector_store, "retrieve_keyword_index", return_value={"name": "test_view"} + ): + with patch.object( + vector_store, "retrieve_vector_index", return_value={"name": "test_index"} + ): + query, bind_vars = vector_store._build_hybrid_search_query( + query="test", + k=3, + embedding=[0.2] * 64, + return_fields=set(), + use_approx=False, + filter_clause="", + vector_weight=0.6, + keyword_weight=0.4, + keyword_search_clause="", + metadata_clause="", + ) + + assert "SUM(" in query + assert "doc.embedding[i] * @embedding[i]" in query + assert "SORT score DESC" in query + assert "LET scored = (" in query + assert "LET maxScore = MAX(scored[*].score)" in query + assert "FILTER item.score == maxScore" in query + assert "item.doc._key" in query + + +def test_build_hybrid_search_query_jaccard( + arango_vector_factory: Any, +) -> None: + """Test _build_hybrid_search_query with JACCARD distance strategy.""" + from langchain_arangodb.vectorstores.utils import DistanceStrategy + + vector_store = arango_vector_factory(distance_strategy=DistanceStrategy.JACCARD) + + # Mock dependencies + with patch.object( + vector_store, "retrieve_keyword_index", return_value={"name": "test_view"} + ): + with patch.object( + vector_store, "retrieve_vector_index", return_value={"name": "test_index"} + ): + query, bind_vars = vector_store._build_hybrid_search_query( + query="test", + k=2, + embedding=[0.1] * 64, + return_fields=set(), + use_approx=False, + filter_clause="", + vector_weight=0.8, + keyword_weight=0.2, + keyword_search_clause="", + metadata_clause="", + ) + + # Should use JACCARD built-in function + assert "JACCARD" in query + assert "SORT score DESC" in query + assert 'WINDOW { preceding: "unbounded", following: 0 }' in query + assert "AGGREGATE rank = COUNT(1)" in query + + def test_build_hybrid_search_query_version_check(arango_vector_factory: Any) -> None: """Test that _build_hybrid_search_query checks ArangoDB version for approximate search.""" @@ -1184,3 +1296,197 @@ def test_search_type_override_in_similarity_search(arango_vector_factory: Any) - mock_vector_search.assert_called_once() assert docs == expected_docs + + +def test_jaccard_distance_strategy(arango_vector_factory: Any) -> None: + """Test JACCARD calculation with actual vectors and query generation.""" + vector_store = arango_vector_factory(distance_strategy=DistanceStrategy.JACCARD) + + query_vector = [1.0, 1.0, 0.0] + + # Test query generation + query, bind_vars = vector_store._build_vector_search_query( + embedding=query_vector, + k=1, + return_fields={"text"}, + use_approx=False, + filter_clause="", + metadata_clause="", + ) + + # Verify JACCARD function is used correctly + assert "JACCARD(doc.embedding, @embedding)" in query + assert bind_vars["embedding"] == query_vector + assert "DESC" in query + + # Test mathematical correctness with real document data + docs = [ + {"_key": "doc1", "text": "identical", "embedding": [1.0, 1.0, 0.0]}, + {"_key": "doc2", "text": "partial", "embedding": [1.0, 0.0, 1.0]}, + {"_key": "doc3", "text": "different", "embedding": [0.0, 0.0, 1.0]}, + ] + + mock_cursor = MagicMock() + # Set up cursor to be empty after first iteration + mock_cursor.empty.side_effect = [False, True] + mock_cursor.has_more.return_value = False + mock_cursor.__iter__ = MagicMock( + return_value=iter( + [ + {"data": docs[0], "score": 1.0, "metadata": {}}, + {"data": docs[1], "score": 0.33, "metadata": {}}, + {"data": docs[2], "score": 0.0, "metadata": {}}, + ] + ) + ) + vector_store.db.aql.execute.return_value = mock_cursor + + # Query with vector [1,1,0] + results = vector_store.similarity_search_with_score("test", k=3, use_approx=False) + scores = [score for _, score in results] + + # Verify results + assert scores == [1.0, 0.33, 0.0] + assert scores[0] > scores[1] > scores[2] + assert len(results) == 3 + + +def test_dot_product_distance_strategy(arango_vector_factory: Any) -> None: + """Test DOT_PRODUCT calculation with actual vectors and query generation.""" + vector_store = arango_vector_factory(distance_strategy=DistanceStrategy.DOT_PRODUCT) + + query_vector = [1.0, 2.0, 3.0] + + # Test query generation + query, bind_vars = vector_store._build_vector_search_query( + embedding=query_vector, + k=1, + return_fields={"text"}, + use_approx=False, + filter_clause="", + metadata_clause="", + ) + + # Verify dot product formula: sum of element-wise multiplication + assert "SUM(" in query + assert "doc.embedding[i] * @embedding[i]" in query + assert "FOR i IN 0..LENGTH(doc.embedding)-1" in query + assert bind_vars["embedding"] == query_vector + assert "DESC" in query + + # Test mathematical correctness with real document data + docs = [ + {"_key": "doc1", "text": "high_score", "embedding": [2.0, 3.0, 1.0]}, + {"_key": "doc2", "text": "medium_score", "embedding": [1.0, 1.0, 1.0]}, + {"_key": "doc3", "text": "low_score", "embedding": [1.0, 0.0, 0.0]}, + ] + + mock_cursor = MagicMock() + # Set up cursor to be empty after first iteration + mock_cursor.empty.side_effect = [False, True] + mock_cursor.has_more.return_value = False + mock_cursor.__iter__ = MagicMock( + return_value=iter( + [ + {"data": docs[0], "score": 11.0, "metadata": {}}, + {"data": docs[1], "score": 6.0, "metadata": {}}, + {"data": docs[2], "score": 1.0, "metadata": {}}, + ] + ) + ) + vector_store.db.aql.execute.return_value = mock_cursor + + # Query with vector [1,2,3] + results = vector_store.similarity_search_with_score("test", k=3, use_approx=False) + scores = [score for _, score in results] + + # Verify results + assert scores == [11.0, 6.0, 1.0] + assert scores[0] > scores[1] > scores[2] + assert len(results) == 3 + + +def test_max_inner_product_strategy(arango_vector_factory: Any) -> None: + """Test MAX_INNER_PRODUCT calculation with actual vectors and query generation.""" + vector_store = arango_vector_factory( + distance_strategy=DistanceStrategy.MAX_INNER_PRODUCT + ) + + query_vector = [1.0, 2.0, 3.0] + + # Test query generation + query, bind_vars = vector_store._build_vector_search_query( + embedding=query_vector, + k=1, + return_fields={"text"}, + use_approx=False, + filter_clause="", + metadata_clause="", + ) + + # Verify inner product calculation (similar to dot product) + assert "SUM(" in query + assert "doc.embedding[i] * @embedding[i]" in query + assert bind_vars["embedding"] == query_vector + assert "DESC" in query + + # Test mathematical correctness with real document data + docs = [ + {"_key": "doc1", "text": "high_score", "embedding": [2.0, 3.0, 1.0]}, + {"_key": "doc2", "text": "medium_score", "embedding": [1.0, 1.0, 1.0]}, + {"_key": "doc3", "text": "low_score", "embedding": [1.0, 0.0, 0.0]}, + ] + + mock_cursor = MagicMock() + # Set up cursor to be empty after first iteration + mock_cursor.empty.side_effect = [False, True] + mock_cursor.has_more.return_value = False + mock_cursor.__iter__ = MagicMock( + return_value=iter( + [ + {"data": docs[0], "score": 11.0, "metadata": {}}, + {"data": docs[1], "score": 6.0, "metadata": {}}, + {"data": docs[2], "score": 1.0, "metadata": {}}, + ] + ) + ) + vector_store.db.aql.execute.return_value = mock_cursor + + # Query with vector [1, 2, 3] + results = vector_store.similarity_search_with_score("test", k=3, use_approx=False) + scores = [score for _, score in results] + + # Verify inner product calculations + assert scores == [11.0, 6.0, 1.0] + assert len(results) == 3 + assert scores[0] > scores[1] > scores[2] + + +def test_distance_strategy_reject_approximate_search( + arango_vector_factory: Any, +) -> None: + """Test new metrics raise exception for approximate search.""" + test_vector = [1.0, 2.0, 3.0] + + for strategy in [ + DistanceStrategy.JACCARD, + DistanceStrategy.DOT_PRODUCT, + DistanceStrategy.MAX_INNER_PRODUCT, + ]: + vector_store = arango_vector_factory(distance_strategy=strategy) + + expected_error = ( + f"Unsupported metric: {strategy} is not supported for approximate search" + ) + with pytest.raises( + ValueError, + match=expected_error, + ): + vector_store._build_vector_search_query( + embedding=test_vector, + k=1, + return_fields={"text"}, + use_approx=True, + filter_clause="", + metadata_clause="", + )