diff --git a/examples/c/CMakeLists.txt b/examples/c/CMakeLists.txt index 126ee5db1..5d06f02b6 100644 --- a/examples/c/CMakeLists.txt +++ b/examples/c/CMakeLists.txt @@ -123,4 +123,4 @@ if(CMAKE_BUILD_TYPE STREQUAL "Release" AND ANDROID) set_property(TARGET c_api_basic_example c_api_collection_schema_example c_api_doc_example c_api_index_example c_api_field_schema_example c_api_optimized_example PROPERTY INTERPROCEDURAL_OPTIMIZATION TRUE) -endif() \ No newline at end of file +endif() diff --git a/python/tests/detail/test_collection_dql.py b/python/tests/detail/test_collection_dql.py index b3749375a..bccce6154 100644 --- a/python/tests/detail/test_collection_dql.py +++ b/python/tests/detail/test_collection_dql.py @@ -843,7 +843,7 @@ def test_query_multivector_rrf(self, full_collection: Collection, doc_num): ) expected_score = expected_rrf_scores[doc_id] actual_score = doc.score - assert abs(actual_score - expected_score) < 1e-10, ( + assert abs(actual_score - expected_score) < 1e-6, ( f"RRF score mismatch for document {doc_id}: expected {expected_score}, got {actual_score}" ) assert doc.score <= prev_score, ( @@ -876,9 +876,8 @@ def test_query_multivector_weighted( batchdoc_and_check(full_collection, multiple_docs, doc_num, operator="insert") doc_fields, doc_vectors = generate_vectordict_random(full_collection.schema) - weighted_reranker = WeightedReRanker( - topn=3, weights=weights, metric=MetricType.IP - ) + metrics = {field: MetricType.IP for field in weights} + weighted_reranker = WeightedReRanker(topn=3, weights=weights, metrics=metrics) single_query_results = {} for k, v in DEFAULT_VECTOR_FIELD_NAME.items(): @@ -911,7 +910,7 @@ def test_query_multivector_weighted( ) expected_score = expected_weighted_scores[doc_id] actual_score = doc.score - assert abs(actual_score - expected_score) < 1e-10, ( + assert abs(actual_score - expected_score) < 1e-6, ( f"score mismatch for document {doc_id}: expected {expected_score}, got {actual_score}" ) assert doc.score <= prev_score, ( diff --git a/python/tests/test_collection.py b/python/tests/test_collection.py index 9b84eb723..cf94a4ecf 100644 --- a/python/tests/test_collection.py +++ b/python/tests/test_collection.py @@ -27,11 +27,17 @@ InvertIndexParam, LogLevel, LogType, + MetricType, OptimizeOption, StatusCode, Query, VectorSchema, ) +from zvec.extension.multi_vector_reranker import ( + CallbackReRanker, + RrfReRanker, + WeightedReRanker, +) # ==================== Common ==================== @@ -60,9 +66,18 @@ def collection_schema(): dimension=128, index_param=HnswIndexParam(), ), + VectorSchema( + "dense2", + DataType.VECTOR_FP32, + dimension=128, + index_param=HnswIndexParam(), + ), VectorSchema( "sparse", DataType.SPARSE_VECTOR_FP32, index_param=HnswIndexParam() ), + VectorSchema( + "sparse2", DataType.SPARSE_VECTOR_FP32, index_param=HnswIndexParam() + ), ], ) @@ -78,7 +93,12 @@ def single_doc(): return Doc( id=f"{id}", fields={"id": id, "name": "test", "weight": 80.0, "height": id + 140}, - vectors={"dense": [id + 0.1] * 128, "sparse": {1: 1.0, 2: 2.0, 3: 3.0}}, + vectors={ + "dense": [id + 0.1] * 128, + "dense2": [id + 0.2] * 128, + "sparse": {1: 1.0, 2: 2.0, 3: 3.0}, + "sparse2": {4: 1.5, 5: 2.5, 6: 3.5}, + }, ) @@ -88,7 +108,12 @@ def multiple_docs(): Doc( id=f"{id}", fields={"id": id, "name": "test", "weight": 80.0, "height": 210}, - vectors={"dense": [id + 0.1] * 128, "sparse": {1: 1.0, 2: 2.0, 3: 3.0}}, + vectors={ + "dense": [id + 0.1] * 128, + "dense2": [id + 0.2] * 128, + "sparse": {1: 1.0, 2: 2.0, 3: 3.0}, + "sparse2": {4: 1.5, 5: 2.5, 6: 3.5}, + }, ) for id in range(1, 101) ] @@ -182,9 +207,11 @@ def test_collection_stats(self, test_collection: Collection): assert test_collection.stats is not None stats = test_collection.stats assert stats.doc_count == 0 - assert len(stats.index_completeness) == 2 + assert len(stats.index_completeness) == 4 assert stats.index_completeness["dense"] == 1 + assert stats.index_completeness["dense2"] == 1 assert stats.index_completeness["sparse"] == 1 + assert stats.index_completeness["sparse2"] == 1 # ---------------------------- @@ -449,7 +476,12 @@ def test_collection_insert_with_nullable_false_field(self, test_collection): "id": 1, "name": "test", }, - vectors={"dense": [1 + 0.1] * 128, "sparse": {1: 1.0, 2: 2.0, 3: 3.0}}, + vectors={ + "dense": [1 + 0.1] * 128, + "dense2": [1 + 0.2] * 128, + "sparse": {1: 1.0, 2: 2.0, 3: 3.0}, + "sparse2": {4: 1.5, 5: 2.5, 6: 3.5}, + }, ) result = test_collection.insert(doc) assert bool(result) @@ -465,7 +497,12 @@ def test_collection_insert_without_nullable_false_field(self, test_collection): # without id, name doc = Doc( id="0", - vectors={"dense": [1 + 0.1] * 128, "sparse": {1: 1.0, 2: 2.0, 3: 3.0}}, + vectors={ + "dense": [1 + 0.1] * 128, + "dense2": [1 + 0.2] * 128, + "sparse": {1: 1.0, 2: 2.0, 3: 3.0}, + "sparse2": {4: 1.5, 5: 2.5, 6: 3.5}, + }, ) with pytest.raises(ValueError) as e: # ValueError: Invalid doc: field[id] is required but not provided @@ -478,7 +515,12 @@ def test_collection_insert_without_nullable_false_field(self, test_collection): fields={ "id": 1, }, - vectors={"dense": [1 + 0.1] * 128, "sparse": {1: 1.0, 2: 2.0, 3: 3.0}}, + vectors={ + "dense": [1 + 0.1] * 128, + "dense2": [1 + 0.2] * 128, + "sparse": {1: 1.0, 2: 2.0, 3: 3.0}, + "sparse2": {4: 1.5, 5: 2.5, 6: 3.5}, + }, ) with pytest.raises(ValueError) as e: test_collection.insert(doc) @@ -494,7 +536,12 @@ def test_collection_insert_with_nullable_true_field(self, test_collection): "id": 1, "name": "test", }, - vectors={"dense": [1 + 0.1] * 128, "sparse": {1: 1.0, 2: 2.0, 3: 3.0}}, + vectors={ + "dense": [1 + 0.1] * 128, + "dense2": [1 + 0.2] * 128, + "sparse": {1: 1.0, 2: 2.0, 3: 3.0}, + "sparse2": {4: 1.5, 5: 2.5, 6: 3.5}, + }, ) result = test_collection.insert(doc) assert bool(result) @@ -969,70 +1016,243 @@ def test_collection_query_by_id( def test_collection_query_multi_vector_with_same_field( self, collection_with_multiple_docs: Collection, multiple_docs ): - with pytest.raises(ValueError): + # Multi-vector query on same field without reranker should raise ValueError + with pytest.raises(ValueError, match="Reranker is required"): collection_with_multiple_docs.query( [ Query(field_name="dense", vector=multiple_docs[0].vector("dense")), - Query(field_name="dense", vector=multiple_docs[0].vector("dense")), + Query(field_name="dense", vector=multiple_docs[1].vector("dense")), ] ) - @pytest.mark.skip(reason="TODO: This test case is pending implementation") + # Same field name with reranker should also raise ValueError + reranker = RrfReRanker(topn=10, rank_constant=60) + with pytest.raises(ValueError, match="appears more than once"): + collection_with_multiple_docs.query( + [ + Query(field_name="dense", vector=multiple_docs[0].vector("dense")), + Query(field_name="dense", vector=multiple_docs[1].vector("dense")), + ], + topk=10, + reranker=reranker, + ) + def test_collection_query_by_dense_vector( self, collection_with_multiple_docs: Collection, multiple_docs ): - pass + result = collection_with_multiple_docs.query( + Query(field_name="dense", vector=multiple_docs[0].vector("dense")), + topk=10, + ) + assert len(result) > 0 + assert len(result) <= 10 - @pytest.mark.skip(reason="TODO: This test case is pending implementation") def test_collection_query_by_sparse_vector( self, collection_with_multiple_docs: Collection, multiple_docs ): - pass + result = collection_with_multiple_docs.query( + Query(field_name="sparse", vector=multiple_docs[0].vector("sparse")), + topk=10, + ) + assert len(result) > 0 + assert len(result) <= 10 - @pytest.mark.skip(reason="TODO: This test case is pending implementation") def test_collection_query_by_dense_vector_with_filter( self, collection_with_multiple_docs: Collection, multiple_docs ): - pass + result = collection_with_multiple_docs.query( + Query(field_name="dense", vector=multiple_docs[0].vector("dense")), + topk=10, + filter="id > 50", + ) + assert len(result) > 0 + assert len(result) <= 10 + for doc in result: + assert int(doc.id) > 50 - @pytest.mark.skip(reason="TODO: This test case is pending implementation") def test_collection_query_by_sparse_vector_with_filter( self, collection_with_multiple_docs: Collection, multiple_docs ): - pass + result = collection_with_multiple_docs.query( + Query(field_name="sparse", vector=multiple_docs[0].vector("sparse")), + topk=10, + filter="id > 50", + ) + assert len(result) > 0 + assert len(result) <= 10 + for doc in result: + assert int(doc.id) > 50 - @pytest.mark.skip(reason="TODO: This test case is pending implementation") def test_collection_query_with_rrf_reranker_by_multi_dense_vector( self, collection_with_multiple_docs: Collection, multiple_docs ): - pass + """Test multi-vector query with RRF reranker on multiple dense vectors.""" + reranker = RrfReRanker(topn=10, rank_constant=60) + result = collection_with_multiple_docs.query( + [ + Query(field_name="dense", vector=multiple_docs[0].vector("dense")), + Query(field_name="dense2", vector=multiple_docs[0].vector("dense2")), + ], + topk=10, + reranker=reranker, + ) + assert len(result) > 0 + assert len(result) <= 10 + # Results should have RRF-fused scores + for doc in result: + assert hasattr(doc, "score") - @pytest.mark.skip(reason="TODO: This test case is pending implementation") def test_collection_query_with_rrf_reranker_by_multi_sparse_vector( self, collection_with_multiple_docs: Collection, multiple_docs ): - pass + """Test multi-vector query with RRF reranker on multiple sparse vectors.""" + reranker = RrfReRanker(topn=10, rank_constant=60) + result = collection_with_multiple_docs.query( + [ + Query(field_name="sparse", vector=multiple_docs[0].vector("sparse")), + Query( + field_name="sparse2", + vector=multiple_docs[0].vector("sparse2"), + ), + ], + topk=10, + reranker=reranker, + ) + assert len(result) > 0 + assert len(result) <= 10 - @pytest.mark.skip(reason="TODO: This test case is pending implementation") def test_collection_query_with_rrf_reranker_by_hybrid_vector( self, collection_with_multiple_docs: Collection, multiple_docs ): - pass + """Test multi-vector query with RRF reranker combining dense + sparse.""" + reranker = RrfReRanker(topn=10, rank_constant=60) + result = collection_with_multiple_docs.query( + [ + Query(field_name="dense", vector=multiple_docs[0].vector("dense")), + Query(field_name="sparse", vector=multiple_docs[0].vector("sparse")), + ], + topk=10, + reranker=reranker, + ) + assert len(result) > 0 + assert len(result) <= 10 - @pytest.mark.skip(reason="TODO: This test case is pending implementation") def test_collection_query_with_weighted_reranker_by_multi_dense_vector( self, collection_with_multiple_docs: Collection, multiple_docs ): - pass + """Test multi-vector query with Weighted reranker on multiple dense vectors.""" + metrics = {"dense": MetricType.IP, "dense2": MetricType.IP} + weights = {"dense": 0.6, "dense2": 0.4} + reranker = WeightedReRanker(topn=10, metrics=metrics, weights=weights) + result = collection_with_multiple_docs.query( + [ + Query(field_name="dense", vector=multiple_docs[0].vector("dense")), + Query(field_name="dense2", vector=multiple_docs[0].vector("dense2")), + ], + topk=10, + reranker=reranker, + ) + assert len(result) > 0 + assert len(result) <= 10 - @pytest.mark.skip(reason="TODO: This test case is pending implementation") def test_collection_query_with_weighted_reranker_by_multi_sparse_vector( self, collection_with_multiple_docs: Collection, multiple_docs ): - pass + """Test multi-vector query with Weighted reranker on multiple sparse vectors.""" + metrics = {"sparse": MetricType.IP, "sparse2": MetricType.IP} + weights = {"sparse": 0.6, "sparse2": 0.4} + reranker = WeightedReRanker(topn=10, metrics=metrics, weights=weights) + result = collection_with_multiple_docs.query( + [ + Query(field_name="sparse", vector=multiple_docs[0].vector("sparse")), + Query( + field_name="sparse2", + vector=multiple_docs[0].vector("sparse2"), + ), + ], + topk=10, + reranker=reranker, + ) + assert len(result) > 0 + assert len(result) <= 10 - @pytest.mark.skip(reason="TODO: This test case is pending implementation") def test_collection_query_with_weighted_reranker_by_hybrid_vector( self, collection_with_multiple_docs: Collection, multiple_docs ): - pass + """Test multi-vector query with Weighted reranker combining dense + sparse.""" + metrics = {"dense": MetricType.IP, "sparse": MetricType.IP} + weights = {"dense": 0.7, "sparse": 0.3} + reranker = WeightedReRanker(topn=10, metrics=metrics, weights=weights) + result = collection_with_multiple_docs.query( + [ + Query(field_name="dense", vector=multiple_docs[0].vector("dense")), + Query(field_name="sparse", vector=multiple_docs[0].vector("sparse")), + ], + topk=10, + reranker=reranker, + ) + assert len(result) > 0 + assert len(result) <= 10 + + def test_collection_query_with_callback_reranker_by_multi_dense_vector( + self, collection_with_multiple_docs: Collection, multiple_docs + ): + """Test multi-vector query with CallbackReRanker (Python callback via C++).""" + callback_invoked = [] + + def my_rerank_callback(query_results, topn): + callback_invoked.append(True) + all_docs = [] + for docs in query_results.values(): + all_docs.extend(docs) + seen = set() + unique_docs = [] + for doc in all_docs: + if doc.pk() not in seen: + seen.add(doc.pk()) + unique_docs.append(doc) + unique_docs.sort(key=lambda d: d.score(), reverse=True) + return unique_docs[:topn] + + reranker = CallbackReRanker(callback=my_rerank_callback, topn=10) + result = collection_with_multiple_docs.query( + [ + Query(field_name="dense", vector=multiple_docs[0].vector("dense")), + Query(field_name="dense2", vector=multiple_docs[0].vector("dense2")), + ], + topk=10, + reranker=reranker, + ) + assert len(callback_invoked) == 1 + assert len(result) > 0 + assert len(result) <= 10 + + def test_collection_query_with_callback_reranker_by_hybrid_vector( + self, collection_with_multiple_docs: Collection, multiple_docs + ): + """Test multi-vector query with CallbackReRanker combining dense + sparse.""" + + def my_rerank_callback(query_results, topn): + all_docs = [] + for docs in query_results.values(): + all_docs.extend(docs) + seen = set() + unique_docs = [] + for doc in all_docs: + if doc.pk() not in seen: + seen.add(doc.pk()) + unique_docs.append(doc) + unique_docs.sort(key=lambda d: d.score(), reverse=True) + return unique_docs[:topn] + + reranker = CallbackReRanker(callback=my_rerank_callback, topn=5) + result = collection_with_multiple_docs.query( + [ + Query(field_name="dense", vector=multiple_docs[0].vector("dense")), + Query(field_name="sparse", vector=multiple_docs[0].vector("sparse")), + ], + topk=5, + reranker=reranker, + ) + assert len(result) > 0 + assert len(result) <= 5 diff --git a/python/tests/test_query_executor.py b/python/tests/test_query_executor.py index 6b9b76356..5907ef09b 100644 --- a/python/tests/test_query_executor.py +++ b/python/tests/test_query_executor.py @@ -31,13 +31,16 @@ ) from zvec import ( RrfReRanker, + WeightedReRanker, HnswQueryParam, CollectionSchema, VectorSchema, DataType, + MetricType, Query, VectorQuery, ) +from zvec.extension.multi_vector_reranker import CallbackReRanker # ---------------------------- @@ -209,6 +212,37 @@ def test_properties(self): assert ctx.output_fields == output_fields assert ctx.include_vector is True + def test_properties_with_weighted_reranker(self): + queries = [Query(field_name="test")] + reranker = WeightedReRanker( + topn=10, + metrics={"test": MetricType.L2}, + weights={"test": 1.0}, + ) + + ctx = QueryContext( + topk=5, + queries=queries, + reranker=reranker, + ) + + assert ctx.reranker == reranker + assert ctx.reranker.weights == {"test": 1.0} + assert ctx.reranker.metrics == {"test": MetricType.L2} + + def test_properties_with_callback_reranker(self): + queries = [Query(field_name="test")] + cb = lambda query_results, topn: [] + reranker = CallbackReRanker(callback=cb, topn=10) + + ctx = QueryContext( + topk=5, + queries=queries, + reranker=reranker, + ) + + assert ctx.reranker == reranker + def test_core_vectors_setter(self): ctx = QueryContext(topk=10) core_vectors = [MagicMock()] @@ -304,6 +338,31 @@ def test_do_validate_multiple_queries_with_reranker(self): executor._do_validate(ctx) + def test_do_validate_multiple_queries_with_weighted_reranker(self): + schema = MockCollectionSchema() + executor = MultiVectorQueryExecutor(schema) + queries = [Query(field_name="test1"), Query(field_name="test2")] + reranker = WeightedReRanker( + topn=10, + metrics={"test1": MetricType.L2, "test2": MetricType.L2}, + weights={"test1": 0.7, "test2": 0.3}, + ) + ctx = QueryContext(topk=10, queries=queries, reranker=reranker) + + executor._do_validate(ctx) + + def test_do_validate_multiple_queries_with_callback_reranker(self): + schema = MockCollectionSchema() + executor = MultiVectorQueryExecutor(schema) + queries = [Query(field_name="test1"), Query(field_name="test2")] + reranker = CallbackReRanker( + callback=lambda query_results, topn: [], + topn=10, + ) + ctx = QueryContext(topk=10, queries=queries, reranker=reranker) + + executor._do_validate(ctx) + class TestQueryExecutorFactory: def test_create_no_vectors(self): diff --git a/python/tests/test_reranker.py b/python/tests/test_reranker.py index dced1dd71..044c042bb 100644 --- a/python/tests/test_reranker.py +++ b/python/tests/test_reranker.py @@ -20,6 +20,7 @@ from zvec import Doc, MetricType from zvec.extension.multi_vector_reranker import ( + CallbackReRanker, RrfReRanker, WeightedReRanker, ) @@ -75,16 +76,17 @@ def test_rerank(self): # ---------------------------- class TestWeightedReRanker: def test_init(self): + metrics = {"vector1": MetricType.L2, "vector2": MetricType.COSINE} weights = {"vector1": 0.7, "vector2": 0.3} reranker = WeightedReRanker( topn=5, rerank_field="content", - metric=MetricType.L2, + metrics=metrics, weights=weights, ) assert reranker.topn == 5 assert reranker.rerank_field == "content" - assert reranker.metric == MetricType.L2 + assert reranker.metrics == metrics assert reranker.weights == weights def test_normalize_score(self): @@ -106,8 +108,9 @@ def test_normalize_score(self): reranker._normalize_score(1.0, "unsupported_metric") def test_rerank(self): + metrics = {"vector1": MetricType.L2, "vector2": MetricType.L2} weights = {"vector1": 0.7, "vector2": 0.3} - reranker = WeightedReRanker(topn=3, weights=weights, metric=MetricType.L2) + reranker = WeightedReRanker(topn=3, weights=weights, metrics=metrics) doc1 = Doc(id="1", score=0.8) doc2 = Doc(id="2", score=0.7) @@ -122,9 +125,70 @@ def test_rerank(self): for doc in results: assert hasattr(doc, "score") + def test_rerank_missing_metric_raises(self): + metrics = {"vector1": MetricType.L2} + reranker = WeightedReRanker(topn=3, metrics=metrics) + + doc1 = Doc(id="1", score=0.8) + query_results = {"vector1": [doc1], "vector2": [doc1]} + + with pytest.raises(ValueError, match="no metric type specified"): + reranker.rerank(query_results) + + +# ---------------------------- +# CallbackReRanker Test Case +# ---------------------------- +class TestCallbackReRanker: + def test_init(self): + def my_callback(query_results, topn): + return [] + + reranker = CallbackReRanker(callback=my_callback, topn=5) + assert reranker.topn == 5 + + def test_rerank(self): + def my_callback(query_results, topn): + all_docs = [] + for docs in query_results.values(): + all_docs.extend(docs) + all_docs.sort(key=lambda d: d.score, reverse=True) + return all_docs[:topn] + + reranker = CallbackReRanker(callback=my_callback, topn=3) + + doc1 = Doc(id="1", score=0.8) + doc2 = Doc(id="2", score=0.9) + doc3 = Doc(id="3", score=0.7) + doc4 = Doc(id="4", score=0.6) + + query_results = {"vector1": [doc1, doc2], "vector2": [doc3, doc4]} + + results = reranker.rerank(query_results) + + assert len(results) == 3 scores = [doc.score for doc in results] assert scores == sorted(scores, reverse=True) + def test_callback_with_topn(self): + received_topn = [] + + def my_callback(query_results, topn): + received_topn.append(topn) + return [] + + reranker = CallbackReRanker(callback=my_callback, topn=7) + reranker.rerank({"v1": [Doc(id="1", score=0.5)]}) + + assert received_topn == [7] + + def test_get_object_returns_cpp_reranker(self): + def my_callback(query_results, topn): + return [] + + reranker = CallbackReRanker(callback=my_callback) + assert reranker._get_object() is not None + # ---------------------------- # QwenReRanker Test Case diff --git a/python/zvec/executor/query_executor.py b/python/zvec/executor/query_executor.py index 3e54e37d2..5cb196542 100644 --- a/python/zvec/executor/query_executor.py +++ b/python/zvec/executor/query_executor.py @@ -19,8 +19,8 @@ from typing import Optional, Union, final import numpy as np -from _zvec import _Collection -from _zvec.param import _VectorQuery +from _zvec import _Collection, _MultiQuery +from _zvec.param import _SubQuery, _VectorQuery from ..extension import ReRanker, RrfReRanker, WeightedReRanker from ..model.convert import convert_to_py_doc @@ -290,11 +290,44 @@ def _do_validate(self, ctx: QueryContext) -> None: raise ValueError(f"Query field name '{field}' appears more than once") seen_fields.add(field) + def execute(self, ctx: QueryContext, collection: _Collection) -> list[Doc]: + # 1. validate query + self._do_validate(ctx) + # 2. build query vectors + query_vectors = self._do_build(ctx, collection) + if not query_vectors: + raise ValueError("No query to execute") + + # Fast path: use C++ MultiQuery for multi-vector with C++ reranker + if len(query_vectors) > 1 and ctx.reranker is not None: + cpp_reranker = ctx.reranker._get_object() + if cpp_reranker is not None: + mvq = _MultiQuery() + mvq.queries = [self._to_sub_query(vq) for vq in query_vectors] + mvq.topk = ctx.topk + if ctx.filter: + mvq.filter = ctx.filter + mvq.include_vector = ctx.include_vector + if ctx.output_fields: + mvq.output_fields = ctx.output_fields + mvq.reranker = cpp_reranker + docs = collection.Query(mvq) + return [convert_to_py_doc(doc, self._schema) for doc in docs] + + # 3. execute query (fallback to Python path) + docs = self._do_execute(query_vectors, collection) + # 4. merge and rerank result + return self._do_merge_rerank_results(ctx, docs) + def _do_execute( self, vectors: list[_VectorQuery], collection: _Collection ) -> dict[str, list[Doc]]: return super()._do_execute(vectors, collection) + @staticmethod + def _to_sub_query(vq: _VectorQuery) -> _SubQuery: + return _SubQuery.from_vector_query(vq) + class QueryExecutorFactory: @staticmethod diff --git a/python/zvec/extension/__init__.py b/python/zvec/extension/__init__.py index 9ff94af29..f738c6c90 100644 --- a/python/zvec/extension/__init__.py +++ b/python/zvec/extension/__init__.py @@ -18,7 +18,7 @@ from .http_embedding_function import HTTPDenseEmbedding from .jina_embedding_function import JinaDenseEmbedding from .jina_function import JinaFunctionBase -from .multi_vector_reranker import RrfReRanker, WeightedReRanker +from .multi_vector_reranker import CallbackReRanker, RrfReRanker, WeightedReRanker from .openai_embedding_function import OpenAIDenseEmbedding from .openai_function import OpenAIFunctionBase from .qwen_embedding_function import QwenDenseEmbedding, QwenSparseEmbedding @@ -34,6 +34,7 @@ __all__ = [ "BM25EmbeddingFunction", + "CallbackReRanker", "DefaultLocalDenseEmbedding", "DefaultLocalReRanker", "DefaultLocalSparseEmbedding", diff --git a/python/zvec/extension/multi_vector_reranker.py b/python/zvec/extension/multi_vector_reranker.py index ba3a2363b..ecb640740 100644 --- a/python/zvec/extension/multi_vector_reranker.py +++ b/python/zvec/extension/multi_vector_reranker.py @@ -16,8 +16,11 @@ import heapq import math from collections import defaultdict +from collections.abc import Callable from typing import Optional +from _zvec import _CallbackReranker, _RrfReranker, _WeightedReranker + from ..model.doc import Doc from ..typing import MetricType from .rerank_function import RerankFunction @@ -51,11 +54,17 @@ def __init__( ): super().__init__(topn=topn, rerank_field=rerank_field) self._rank_constant = rank_constant + # Use C++ implementation for performance + self._cpp_reranker = _RrfReranker(rank_constant) @property def rank_constant(self) -> int: return self._rank_constant + def _get_object(self): + """Return the underlying C++ RrfReranker instance.""" + return self._cpp_reranker + def _rrf_score(self, rank: int) -> float: return 1.0 / (self._rank_constant + rank + 1) @@ -91,8 +100,9 @@ def rerank(self, query_results: dict[str, list[Doc]]) -> list[Doc]: class WeightedReRanker(RerankFunction): """Re-ranker that combines scores from multiple vector fields using weights. - Each vector field's relevance score is normalized based on its metric type, - then scaled by a user-provided weight. Final scores are summed across fields. + Each vector field's relevance score is normalized based on its own metric + type, then scaled by a user-provided weight. Final scores are summed across + fields. Note: This re-ranker is specifically designed for multi-vector scenarios where @@ -102,8 +112,10 @@ class WeightedReRanker(RerankFunction): Args: topn (int, optional): Number of top documents to return. Defaults to 10. rerank_field (Optional[str], optional): Ignored. Defaults to None. - metric (MetricType, optional): Distance metric used for score normalization. - Defaults to ``MetricType.L2``. + metrics (Optional[dict[str, MetricType]], optional): Per-field distance + metric used for score normalization. Every queried field must have + a metric specified; missing fields will raise an error at rerank time. + Defaults to None. weights (Optional[dict[str, float]], optional): Weight per vector field. Fields not listed use weight 1.0. Defaults to None. @@ -115,12 +127,12 @@ def __init__( self, topn: int = 10, rerank_field: Optional[str] = None, - metric: MetricType = MetricType.L2, + metrics: Optional[dict[str, MetricType]] = None, weights: Optional[dict[str, float]] = None, ): super().__init__(topn=topn, rerank_field=rerank_field) self._weights = weights or {} - self._metric = metric + self._metrics = metrics or {} @property def weights(self) -> dict[str, float]: @@ -128,9 +140,13 @@ def weights(self) -> dict[str, float]: return self._weights @property - def metric(self) -> MetricType: - """MetricType: Distance metric used for score normalization.""" - return self._metric + def metrics(self) -> dict[str, MetricType]: + """dict[str, MetricType]: Per-field metric type mapping.""" + return self._metrics + + def _get_object(self): + """Return a C++ WeightedReranker instance.""" + return _WeightedReranker(self._weights) def rerank(self, query_results: dict[str, list[Doc]]) -> list[Doc]: """Combine scores from multiple vector fields using weighted sum. @@ -145,10 +161,16 @@ def rerank(self, query_results: dict[str, list[Doc]]) -> list[Doc]: id_to_doc: dict[str, Doc] = {} for vector_name, query_result in query_results.items(): + if vector_name not in self._metrics: + raise ValueError( + f"WeightedReRanker: no metric type specified for field " + f"'{vector_name}'" + ) + metric = self._metrics[vector_name] for _, doc in enumerate(query_result): doc_id = doc.id weighted_score = self._normalize_score( - doc.score, self.metric + doc.score, metric ) * self.weights.get(vector_name, 1.0) weighted_scores[doc_id] += weighted_score if doc_id not in id_to_doc: @@ -172,3 +194,43 @@ def _normalize_score(self, score: float, metric: MetricType) -> float: if metric == MetricType.COSINE: return 1.0 - score / 2.0 raise ValueError("Unsupported metric type") + + +class CallbackReRanker(RerankFunction): + """Re-ranker that delegates to a user-provided Python callback. + + This bridges a Python callable into the C++ reranker interface, enabling + custom re-ranking logic to be executed within the C++ MultiQuery path. + + The callback receives the raw C++ Doc objects (as ``_Doc`` instances) grouped + by vector field name, and must return a list of ``_Doc`` instances. + + Args: + callback: A callable with signature + ``(query_results: dict[str, list[_Doc]], topn: int) -> list[_Doc]``. + topn (int, optional): Number of top documents to return. Defaults to 10. + """ + + def __init__( + self, + callback: Callable, + topn: int = 10, + ): + super().__init__(topn=topn) + self._callback = callback + self._cpp_reranker = _CallbackReranker(callback) + + def _get_object(self): + """Return the underlying C++ CallbackReranker instance.""" + return self._cpp_reranker + + def rerank(self, query_results: dict[str, list[Doc]]) -> list[Doc]: + """Invoke the callback to re-rank documents. + + Args: + query_results (dict[str, list[Doc]]): Results per vector field. + + Returns: + list[Doc]: Re-ranked documents. + """ + return self._callback(query_results, self.topn) diff --git a/python/zvec/extension/rerank_function.py b/python/zvec/extension/rerank_function.py index c558a2bc4..0d8d00263 100644 --- a/python/zvec/extension/rerank_function.py +++ b/python/zvec/extension/rerank_function.py @@ -67,3 +67,15 @@ def rerank(self, query_results: dict[str, list[Doc]]) -> list[Doc]: with updated ``score`` fields. """ ... + + def _get_object(self): + """Return the underlying C++ Reranker instance, if available. + + This is used internally by the query executor to pass the reranker + to the C++ MultiQuery method. Subclasses that wrap a C++ reranker + should override this method. + + Returns: + The C++ Reranker shared pointer, or None if not available. + """ + return None # noqa: RET501 diff --git a/src/binding/c/c_api.cc b/src/binding/c/c_api.cc index b23c7ecd8..170bdd9e0 100644 --- a/src/binding/c/c_api.cc +++ b/src/binding/c/c_api.cc @@ -27,12 +27,14 @@ #include #include #include +#include #include #include #include #include #include #include +#include #include #include @@ -5340,6 +5342,363 @@ zvec_error_code_t zvec_group_by_vector_query_set_flat_params( return ZVEC_OK; } +// ============================================================================= +// Reranker Implementation +// ============================================================================= + +zvec_reranker_t *zvec_reranker_create_rrf(int rank_constant) { + ZVEC_TRY_RETURN_NULL("Failed to create RRF Reranker", + auto *reranker = + new zvec::Reranker::Ptr( + std::make_shared( + rank_constant)); + return reinterpret_cast(reranker);) + return nullptr; +} + +zvec_reranker_t *zvec_reranker_create_weighted(const char **fields, + const double *weights, + size_t field_count) { + if ((!fields || !weights) && field_count > 0) { + set_last_error( + "Fields and weights pointers cannot be null when field_count > 0"); + return nullptr; + } + + ZVEC_TRY_RETURN_NULL( + "Failed to create Weighted Reranker", + std::map weight_map; + for (size_t i = 0; i < field_count; ++i) { + if (!fields[i]) { + set_last_error("Null field name at index " + std::to_string(i)); + return nullptr; + } + weight_map[fields[i]] = weights[i]; + } + + auto *reranker = new zvec::Reranker::Ptr( + std::make_shared(weight_map)); + return reinterpret_cast(reranker);) + return nullptr; +} + +void zvec_reranker_destroy(zvec_reranker_t *reranker) { + if (reranker) { + delete reinterpret_cast(reranker); + } +} + +int zvec_reranker_get_rank_constant(const zvec_reranker_t *reranker) { + if (!reranker) return -1; + auto *ptr = reinterpret_cast(reranker); + auto *rrf = dynamic_cast(ptr->get()); + return rrf ? rrf->rank_constant() : -1; +} + +// ============================================================================= +// MultiVectorQuery Implementation +// ============================================================================= + +zvec_multi_query_t *zvec_multi_query_create(void) { + ZVEC_TRY_RETURN_NULL("Failed to create MultiVectorQuery", + auto *query = new zvec::MultiQuery(); + return reinterpret_cast( + query);) + return nullptr; +} + +void zvec_multi_query_destroy(zvec_multi_query_t *query) { + if (query) { + delete reinterpret_cast(query); + } +} + +zvec_error_code_t zvec_multi_query_add_sub_query( + zvec_multi_query_t *query, + const zvec_sub_query_t *sub_query) { + if (!query || !sub_query) { + SET_LAST_ERROR(ZVEC_ERROR_INVALID_ARGUMENT, + "Query or sub_query pointer is null"); + return ZVEC_ERROR_INVALID_ARGUMENT; + } + + auto *mvq = reinterpret_cast(query); + auto *sub = reinterpret_cast(sub_query); + mvq->queries.push_back(*sub); + + return ZVEC_OK; +} + +size_t zvec_multi_query_get_sub_query_count( + const zvec_multi_query_t *query) { + if (!query) return 0; + auto *mvq = reinterpret_cast(query); + return mvq->queries.size(); +} + +zvec_error_code_t zvec_multi_query_set_topk( + zvec_multi_query_t *query, int topk) { + if (!query) { + SET_LAST_ERROR(ZVEC_ERROR_INVALID_ARGUMENT, + "Multi-vector query pointer is null"); + return ZVEC_ERROR_INVALID_ARGUMENT; + } + auto *mvq = reinterpret_cast(query); + mvq->topk = topk; + return ZVEC_OK; +} + +int zvec_multi_query_get_topk( + const zvec_multi_query_t *query) { + if (!query) return 0; + auto *mvq = reinterpret_cast(query); + return mvq->topk; +} + +zvec_error_code_t zvec_multi_query_set_filter( + zvec_multi_query_t *query, const char *filter) { + if (!query || !filter) { + SET_LAST_ERROR(ZVEC_ERROR_INVALID_ARGUMENT, + "Query or filter pointer is null"); + return ZVEC_ERROR_INVALID_ARGUMENT; + } + auto *mvq = reinterpret_cast(query); + mvq->filter = std::string(filter); + return ZVEC_OK; +} + +const char *zvec_multi_query_get_filter( + const zvec_multi_query_t *query) { + if (!query) return nullptr; + auto *mvq = reinterpret_cast(query); + return mvq->filter.c_str(); +} + +zvec_error_code_t zvec_multi_query_set_include_vector( + zvec_multi_query_t *query, bool include) { + if (!query) { + SET_LAST_ERROR(ZVEC_ERROR_INVALID_ARGUMENT, + "Multi-vector query pointer is null"); + return ZVEC_ERROR_INVALID_ARGUMENT; + } + auto *mvq = reinterpret_cast(query); + mvq->include_vector = include; + return ZVEC_OK; +} + +bool zvec_multi_query_get_include_vector( + const zvec_multi_query_t *query) { + if (!query) return false; + auto *mvq = reinterpret_cast(query); + return mvq->include_vector; +} + +zvec_error_code_t zvec_multi_query_set_output_fields( + zvec_multi_query_t *query, const char **fields, size_t count) { + if (!query || (!fields && count > 0)) { + SET_LAST_ERROR(ZVEC_ERROR_INVALID_ARGUMENT, + "Query pointer is null or fields is null with count > 0"); + return ZVEC_ERROR_INVALID_ARGUMENT; + } + + auto *mvq = reinterpret_cast(query); + std::vector field_vec; + field_vec.reserve(count); + for (size_t i = 0; i < count; ++i) { + if (!fields[i]) { + SET_LAST_ERROR(ZVEC_ERROR_INVALID_ARGUMENT, + "Null field name at index " + std::to_string(i)); + return ZVEC_ERROR_INVALID_ARGUMENT; + } + field_vec.emplace_back(fields[i]); + } + mvq->output_fields = std::move(field_vec); + + return ZVEC_OK; +} + +zvec_error_code_t zvec_multi_query_get_output_fields( + zvec_multi_query_t *query, const char ***fields, size_t *count) { + if (!query || !fields || !count) { + SET_LAST_ERROR(ZVEC_ERROR_INVALID_ARGUMENT, + "Query, fields or count pointer is null"); + return ZVEC_ERROR_INVALID_ARGUMENT; + } + + auto *mvq = reinterpret_cast(query); + if (!mvq->output_fields.has_value() || mvq->output_fields->empty()) { + *fields = nullptr; + *count = 0; + return ZVEC_OK; + } + + const auto &field_vec = mvq->output_fields.value(); + *count = field_vec.size(); + *fields = static_cast(malloc(*count * sizeof(const char *))); + if (!*fields) { + SET_LAST_ERROR(ZVEC_ERROR_INTERNAL_ERROR, "Failed to allocate memory"); + return ZVEC_ERROR_INTERNAL_ERROR; + } + for (size_t i = 0; i < *count; ++i) { + (*fields)[i] = field_vec[i].c_str(); + } + + return ZVEC_OK; +} + +zvec_error_code_t zvec_multi_query_set_reranker( + zvec_multi_query_t *query, zvec_reranker_t *reranker) { + if (!query || !reranker) { + SET_LAST_ERROR(ZVEC_ERROR_INVALID_ARGUMENT, + "Query or reranker pointer is null"); + return ZVEC_ERROR_INVALID_ARGUMENT; + } + + auto *mvq = reinterpret_cast(query); + auto *reranker_ptr = + reinterpret_cast(reranker); + mvq->reranker = *reranker_ptr; + + return ZVEC_OK; +} + +// ============================================================================= +// SubVectorQuery Implementation +// ============================================================================= + +zvec_sub_query_t *zvec_sub_query_create(void) { + ZVEC_TRY_RETURN_NULL("Failed to create SubVectorQuery", + auto *query = new zvec::SubQuery(); + query->num_candidates_ = 10; + return reinterpret_cast( + query);) + return nullptr; +} + +void zvec_sub_query_destroy(zvec_sub_query_t *query) { + if (query) { + delete reinterpret_cast(query); + } +} + +zvec_error_code_t zvec_sub_query_set_num_candidates( + zvec_sub_query_t *query, int num_candidates) { + if (!query) { + SET_LAST_ERROR(ZVEC_ERROR_INVALID_ARGUMENT, + "Sub-vector query pointer is null"); + return ZVEC_ERROR_INVALID_ARGUMENT; + } + auto *ptr = reinterpret_cast(query); + ptr->num_candidates_ = num_candidates; + return ZVEC_OK; +} + +int zvec_sub_query_get_num_candidates( + const zvec_sub_query_t *query) { + if (!query) return 0; + auto *ptr = reinterpret_cast(query); + return ptr->num_candidates_; +} + +zvec_error_code_t zvec_sub_query_set_field_name( + zvec_sub_query_t *query, const char *field_name) { + if (!query || !field_name) { + SET_LAST_ERROR(ZVEC_ERROR_INVALID_ARGUMENT, + "Sub-vector query or field_name pointer is null"); + return ZVEC_ERROR_INVALID_ARGUMENT; + } + auto *ptr = reinterpret_cast(query); + ptr->target_.field_name_ = std::string(field_name); + return ZVEC_OK; +} + +const char *zvec_sub_query_get_field_name( + const zvec_sub_query_t *query) { + if (!query) return nullptr; + auto *ptr = reinterpret_cast(query); + return ptr->target_.field_name_.c_str(); +} + +zvec_error_code_t zvec_sub_query_set_query_vector( + zvec_sub_query_t *query, const void *data, size_t size) { + if (!query || !data || size == 0) { + SET_LAST_ERROR(ZVEC_ERROR_INVALID_ARGUMENT, + "Sub-vector query pointer or data is null/empty"); + return ZVEC_ERROR_INVALID_ARGUMENT; + } + auto *ptr = reinterpret_cast(query); + auto &payload = std::get(ptr->target_.clause_); + payload.query_vector_.assign(static_cast(data), size); + return ZVEC_OK; +} + +zvec_error_code_t zvec_sub_query_set_sparse_indices( + zvec_sub_query_t *query, const uint32_t *indices, size_t count) { + if (!query || (!indices && count > 0)) { + SET_LAST_ERROR(ZVEC_ERROR_INVALID_ARGUMENT, + "Sub-vector query or indices pointer is null"); + return ZVEC_ERROR_INVALID_ARGUMENT; + } + auto *ptr = reinterpret_cast(query); + auto &payload = std::get(ptr->target_.clause_); + payload.sparse_indices_.assign( + reinterpret_cast(indices), count * sizeof(uint32_t)); + return ZVEC_OK; +} + +zvec_error_code_t zvec_sub_query_set_sparse_values( + zvec_sub_query_t *query, const float *values, size_t count) { + if (!query || (!values && count > 0)) { + SET_LAST_ERROR(ZVEC_ERROR_INVALID_ARGUMENT, + "Sub-vector query or values pointer is null"); + return ZVEC_ERROR_INVALID_ARGUMENT; + } + auto *ptr = reinterpret_cast(query); + auto &payload = std::get(ptr->target_.clause_); + payload.sparse_values_.assign( + reinterpret_cast(values), count * sizeof(float)); + return ZVEC_OK; +} + +zvec_error_code_t zvec_sub_query_set_hnsw_params( + zvec_sub_query_t *query, zvec_hnsw_query_params_t *hnsw_params) { + if (!query || !hnsw_params) { + SET_LAST_ERROR(ZVEC_ERROR_INVALID_ARGUMENT, + "Sub-vector query or HNSW params pointer is null"); + return ZVEC_ERROR_INVALID_ARGUMENT; + } + auto *ptr = reinterpret_cast(query); + auto *params_ptr = reinterpret_cast(hnsw_params); + ptr->target_.query_params_.reset(params_ptr); + return ZVEC_OK; +} + +zvec_error_code_t zvec_sub_query_set_ivf_params( + zvec_sub_query_t *query, zvec_ivf_query_params_t *ivf_params) { + if (!query || !ivf_params) { + SET_LAST_ERROR(ZVEC_ERROR_INVALID_ARGUMENT, + "Sub-vector query or IVF params pointer is null"); + return ZVEC_ERROR_INVALID_ARGUMENT; + } + auto *ptr = reinterpret_cast(query); + auto *params_ptr = reinterpret_cast(ivf_params); + ptr->target_.query_params_.reset(params_ptr); + return ZVEC_OK; +} + +zvec_error_code_t zvec_sub_query_set_flat_params( + zvec_sub_query_t *query, zvec_flat_query_params_t *flat_params) { + if (!query || !flat_params) { + SET_LAST_ERROR(ZVEC_ERROR_INVALID_ARGUMENT, + "Sub-vector query or Flat params pointer is null"); + return ZVEC_ERROR_INVALID_ARGUMENT; + } + auto *ptr = reinterpret_cast(query); + auto *params_ptr = reinterpret_cast(flat_params); + ptr->target_.query_params_.reset(params_ptr); + return ZVEC_OK; +} + // ============================================================================= // Index Interface Implementation // ============================================================================= @@ -5998,6 +6357,41 @@ zvec_error_code_t zvec_collection_query(const zvec_collection_t *collection, return error_code;) } +zvec_error_code_t zvec_collection_multi_query( + const zvec_collection_t *collection, + const zvec_multi_query_t *query, + zvec_doc_t ***results, size_t *result_count) { + if (!collection || !query || !results || !result_count) { + set_last_error( + "Invalid arguments: collection, query, results and result_count " + "cannot be null"); + return ZVEC_ERROR_INVALID_ARGUMENT; + } + + ZVEC_TRY_RETURN_ERROR( + "Exception occurred", + auto coll_ptr = + reinterpret_cast *>( + collection); + + auto *internal_query = + reinterpret_cast(query); + + auto result = (*coll_ptr)->Query(*internal_query); + zvec_error_code_t error_code = handle_expected_result(result); + + if (error_code == ZVEC_OK) { + const auto &query_results = result.value(); + error_code = + convert_document_results(query_results, results, result_count); + } else { + *results = nullptr; + *result_count = 0; + } + + return error_code;) +} + zvec_error_code_t zvec_collection_fetch(zvec_collection_t *collection, const char *const *pks, size_t pk_count, const char *const *output_fields, diff --git a/src/binding/python/CMakeLists.txt b/src/binding/python/CMakeLists.txt index d17f56289..a02287c8d 100644 --- a/src/binding/python/CMakeLists.txt +++ b/src/binding/python/CMakeLists.txt @@ -10,6 +10,7 @@ set(SRC_LISTS binding.cc model/python_collection.cc model/python_doc.cc + model/python_reranker.cc model/param/python_param.cc model/schema/python_schema.cc model/common/python_config.cc diff --git a/src/binding/python/binding.cc b/src/binding/python/binding.cc index ed8d6918d..c1bdad367 100644 --- a/src/binding/python/binding.cc +++ b/src/binding/python/binding.cc @@ -16,6 +16,7 @@ #include "python_config.h" #include "python_doc.h" #include "python_param.h" +#include "python_reranker.h" #include "python_schema.h" #include "python_type.h" @@ -26,6 +27,7 @@ PYBIND11_MODULE(_zvec, m) { ZVecPyTyping::Initialize(m); ZVecPyParams::Initialize(m); ZVecPySchemas::Initialize(m); + ZVecPyReranker::Initialize(m); ZVecPyConfig::Initialize(m); ZVecPyDoc::Initialize(m); ZVecPyCollection::Initialize(m); diff --git a/src/binding/python/include/python_reranker.h b/src/binding/python/include/python_reranker.h new file mode 100644 index 000000000..4ab174a62 --- /dev/null +++ b/src/binding/python/include/python_reranker.h @@ -0,0 +1,31 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include +#include + +namespace py = pybind11; + +namespace zvec { + +class ZVecPyReranker { + public: + ZVecPyReranker() = delete; + + public: + static void Initialize(py::module_ &m); +}; + +} // namespace zvec diff --git a/src/binding/python/model/param/python_param.cc b/src/binding/python/model/param/python_param.cc index 268246cd6..86376db45 100644 --- a/src/binding/python/model/param/python_param.cc +++ b/src/binding/python/model/param/python_param.cc @@ -17,6 +17,7 @@ #include #include #include +#include #include "python_doc.h" namespace zvec { @@ -1372,6 +1373,25 @@ Constructs an AlterColumnOption instance. } void ZVecPyParams::bind_vector_query(py::module_ &m) { + // Bind SubQuery (used by MultiQuery) + py::class_(m, "_SubQuery") + .def(py::init<>()) + .def_readwrite("num_candidates", &SubQuery::num_candidates_) + .def_static( + "from_vector_query", + [](const VectorQuery &vq) { + SubQuery sub; + sub.num_candidates_ = vq.topk_; + sub.target_.field_name_ = vq.field_name_; + auto &clause = std::get(sub.target_.clause_); + clause.query_vector_ = vq.query_vector_; + clause.sparse_indices_ = vq.query_sparse_indices_; + clause.sparse_values_ = vq.query_sparse_values_; + sub.target_.query_params_ = vq.query_params_; + return sub; + }, + py::arg("vector_query"), "Create a SubQuery from a VectorQuery"); + py::class_(m, "_VectorQuery") .def(py::init<>()) // properties diff --git a/src/binding/python/model/python_collection.cc b/src/binding/python/model/python_collection.cc index d5e6b3203..d902e028b 100644 --- a/src/binding/python/model/python_collection.cc +++ b/src/binding/python/model/python_collection.cc @@ -261,6 +261,19 @@ void ZVecPyCollection::bind_dql_methods( // return DocPtrList return unwrap_expected(result); }) + // MultiQuery: multi query with reranker + .def( + "Query", + [](const Collection &self, const MultiQuery &query) { + Result result; + { + py::gil_scoped_release release; + result = self.Query(query); + } + // return DocPtrList + return unwrap_expected(result); + }, + py::arg("query"), "Execute a multi query with re-ranking.") .def("GroupByQuery", [](const Collection &self, const GroupByVectorQuery &query) { Result result; diff --git a/src/binding/python/model/python_reranker.cc b/src/binding/python/model/python_reranker.cc new file mode 100644 index 000000000..e8a33741a --- /dev/null +++ b/src/binding/python/model/python_reranker.cc @@ -0,0 +1,59 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "python_reranker.h" +#include +#include +#include +#include + +namespace zvec { + +void ZVecPyReranker::Initialize(py::module_ &m) { + // Bind Reranker base class (abstract, cannot be instantiated directly) + py::class_(m, "_Reranker"); + + // Bind ScoreBasedReranker intermediate class + py::class_>( + m, "_ScoreBasedReranker"); + + // Bind RrfReranker + py::class_>( + m, "_RrfReranker") + .def(py::init(), py::arg("rank_constant") = 60) + .def_property_readonly("rank_constant", &RrfReranker::rank_constant); + + // Bind WeightedReranker + py::class_>(m, "_WeightedReranker") + .def(py::init>(), py::arg("weights")) + .def_property_readonly("weights", &WeightedReranker::weights); + + // Bind CallbackReranker + py::class_>( + m, "_CallbackReranker") + .def(py::init(), py::arg("callback")); + + // Bind MultiQuery struct + py::class_(m, "_MultiQuery") + .def(py::init<>()) + .def_readwrite("queries", &MultiQuery::queries) + .def_readwrite("topk", &MultiQuery::topk) + .def_readwrite("filter", &MultiQuery::filter) + .def_readwrite("include_vector", &MultiQuery::include_vector) + .def_readwrite("output_fields", &MultiQuery::output_fields) + .def_readwrite("reranker", &MultiQuery::reranker); +} + +} // namespace zvec diff --git a/src/db/collection.cc b/src/db/collection.cc index 36f9a7420..05bde1678 100644 --- a/src/db/collection.cc +++ b/src/db/collection.cc @@ -15,8 +15,10 @@ #include #include #include +#include #include #include +#include #include #include #include @@ -29,6 +31,7 @@ #include #include #include +#include #include #include #include "db/common/constants.h" @@ -117,6 +120,8 @@ class CollectionImpl : public Collection { Result Query(const VectorQuery &query) const override; + Result Query(const MultiQuery &query) const override; + Result GroupByQuery( const GroupByVectorQuery &query) const override; @@ -1597,6 +1602,86 @@ Result CollectionImpl::Query(const VectorQuery &query) const { return sql_engine_->execute(schema_, sanitized, segments); } +Result CollectionImpl::Query(const MultiQuery &query) const { + std::shared_lock lock(schema_handle_mtx_); + + CHECK_DESTROY_RETURN_STATUS_EXPECTED(destroyed_, false); + + if (query.queries.size() < 2) { + return tl::make_unexpected( + Status::InvalidArgument("Query requires at least 2 sub-queries")); + } + + if (!query.reranker) { + return tl::make_unexpected( + Status::InvalidArgument("Reranker is required for multi-vector query")); + } + + auto segments = get_all_segments(); + if (segments.empty()) { + return DocPtrList(); + } + + // Convert SubVectorQuery to VectorQuery and validate + std::set seen_fields; + std::vector converted_queries; + converted_queries.reserve(query.queries.size()); + + for (const auto &sub : query.queries) { + const auto &target = sub.target_; + auto [_, inserted] = seen_fields.insert(target.field_name_); + if (!inserted) { + return tl::make_unexpected(Status::InvalidArgument( + "Duplicate field name in multi-vector query: ", target.field_name_)); + } + auto *field_schema = schema_->get_vector_field(target.field_name_); + if (!field_schema) { + return tl::make_unexpected(Status::InvalidArgument( + "Vector field not found: ", target.field_name_)); + } + + VectorQuery vq; + vq.topk_ = sub.num_candidates_; + vq.field_name_ = target.field_name_; + const auto &vec_clause = std::get(target.clause_); + vq.query_vector_ = vec_clause.query_vector_; + vq.query_sparse_indices_ = vec_clause.sparse_indices_; + vq.query_sparse_values_ = vec_clause.sparse_values_; + vq.query_params_ = target.query_params_; + vq.filter_ = query.filter; + vq.include_vector_ = query.include_vector; + vq.include_doc_id_ = query.include_doc_id_; + vq.output_fields_ = query.output_fields; + + auto s = vq.validate_and_sanitize(field_schema); + CHECK_RETURN_STATUS_EXPECTED(s); + converted_queries.push_back(std::move(vq)); + } + + // Execute each VectorQuery concurrently and collect results per field + std::vector>> futures; + futures.reserve(converted_queries.size()); + for (const auto &vq : converted_queries) { + futures.push_back(std::async(std::launch::async, [&]() { + auto engine = sqlengine::SQLEngine::create(std::make_shared()); + return engine->execute(schema_, vq, segments); + })); + } + + std::map query_results; + for (size_t i = 0; i < converted_queries.size(); ++i) { + auto result = futures[i].get(); + if (!result.has_value()) { + return tl::make_unexpected(result.error()); + } + query_results[converted_queries[i].field_name_] = std::move(result.value()); + } + + // Merge and rerank results + query.reranker->bind_schema(schema_); + return query.reranker->rerank(query_results, query.topk); +} + Result CollectionImpl::GroupByQuery( const GroupByVectorQuery &query) const { std::shared_lock lock(schema_handle_mtx_); diff --git a/src/db/index/common/doc.cc b/src/db/index/common/doc.cc index 0405eac1d..fe8be100e 100644 --- a/src/db/index/common/doc.cc +++ b/src/db/index/common/doc.cc @@ -22,6 +22,7 @@ #include #include #include +#include #include "db/common/constants.h" #include "db/index/common/type_helper.h" diff --git a/src/db/reranker/reranker.cc b/src/db/reranker/reranker.cc new file mode 100644 index 000000000..252e88f1a --- /dev/null +++ b/src/db/reranker/reranker.cc @@ -0,0 +1,138 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "zvec/db/status.h" +#define _USE_MATH_DEFINES +#include +#include +#include +#include +#include +#include +#include + +namespace zvec { + +// ==================== ScoreBasedReranker ==================== + +Result ScoreBasedReranker::rerank( + const std::map &query_results, int topn) const { + std::unordered_map scores; + std::unordered_map id_to_doc; + + for (const auto &[field_name, docs] : query_results) { + for (size_t rank = 0; rank < docs.size(); ++rank) { + const auto &doc = docs[rank]; + const std::string &doc_id = doc->pk(); + auto rs = rescore(static_cast(doc->score()), + static_cast(rank), field_name); + if (!rs.has_value()) { + return tl::make_unexpected(rs.error()); + } + scores[doc_id] += rs.value(); + if (id_to_doc.find(doc_id) == id_to_doc.end()) { + id_to_doc[doc_id] = doc; + } + } + } + + using ScorePair = std::pair; + auto cmp = [](const ScorePair &a, const ScorePair &b) { + return a.second > b.second; + }; + std::priority_queue, decltype(cmp)> pq(cmp); + + for (const auto &[doc_id, score] : scores) { + if (static_cast(pq.size()) < topn) { + pq.emplace(doc_id, score); + } else if (score > pq.top().second) { + pq.pop(); + pq.emplace(doc_id, score); + } + } + + DocPtrList results; + results.reserve(pq.size()); + while (!pq.empty()) { + const auto &[doc_id, score] = pq.top(); + auto doc = std::move(id_to_doc[doc_id]); + doc->set_score(static_cast(score)); + results.push_back(std::move(doc)); + pq.pop(); + } + std::reverse(results.begin(), results.end()); + return results; +} + +// ==================== RrfReranker ==================== + +Result RrfReranker::rescore(double /*score*/, int rank, + const std::string & /*field_name*/) const { + return 1.0 / (static_cast(rank_constant_) + + static_cast(rank) + 1.0); +} + +// ==================== WeightedReranker ==================== + +WeightedReranker::WeightedReranker(const std::map &weights) + : weights_(weights) {} + +void WeightedReranker::bind_schema(CollectionSchema::Ptr schema) { + schema_ = std::move(schema); +} + +Result WeightedReranker::normalize_score(double score, + const FieldSchema &field) { + auto *vip = + dynamic_cast(field.index_params().get()); + if (!vip) { + return tl::make_unexpected( + Status::InvalidArgument("WeightedReranker: field '", field.name(), + "' has no vector index params")); + } + switch (vip->metric_type()) { + case MetricType::L2: + return 1.0 - 2.0 * std::atan(score) / M_PI; + case MetricType::IP: + return 0.5 + std::atan(score) / M_PI; + case MetricType::COSINE: + return 1.0 - score / 2.0; + default: + return tl::make_unexpected(Status::InvalidArgument( + "Unsupported metric type for normalization: ", + std::to_string(static_cast(vip->metric_type())))); + } +} + +Result WeightedReranker::rescore(double score, int /*rank*/, + const std::string &field_name) const { + const auto *field = schema_->get_vector_field(field_name); + if (!field) { + return tl::make_unexpected(Status::InvalidArgument( + "WeightedReranker: vector field not found: '", field_name + "'")); + } + auto normalized = normalize_score(score, *field); + if (!normalized.has_value()) { + return tl::make_unexpected(normalized.error()); + } + double weight = 1.0; + auto weight_it = weights_.find(field_name); + if (weight_it != weights_.end()) { + weight = weight_it->second; + } + return normalized.value() * weight; +} + +} // namespace zvec diff --git a/src/db/sqlengine/parser/sql_info_helper.h b/src/db/sqlengine/parser/sql_info_helper.h index 465ccdce2..760dbc4e3 100644 --- a/src/db/sqlengine/parser/sql_info_helper.h +++ b/src/db/sqlengine/parser/sql_info_helper.h @@ -14,7 +14,7 @@ #pragma once -#include +#include #include "db/sqlengine/common/group_by.h" #include "db/sqlengine/parser/node.h" #include "db/sqlengine/parser/sql_info.h" diff --git a/src/db/sqlengine/sqlengine.h b/src/db/sqlengine/sqlengine.h index d86fd69bf..47143b60f 100644 --- a/src/db/sqlengine/sqlengine.h +++ b/src/db/sqlengine/sqlengine.h @@ -14,7 +14,7 @@ #pragma once -#include +#include #include #include "db/common/profiler.h" #include "db/index/segment/segment.h" diff --git a/src/db/sqlengine/sqlengine_impl.h b/src/db/sqlengine/sqlengine_impl.h index 88c279283..5e258346b 100644 --- a/src/db/sqlengine/sqlengine_impl.h +++ b/src/db/sqlengine/sqlengine_impl.h @@ -18,7 +18,7 @@ #include #include #include -#include +#include #include #include "analyzer/query_info.h" #include "common/group_by.h" diff --git a/src/include/zvec/c_api.h b/src/include/zvec/c_api.h index 74cc1bfbd..187f87046 100644 --- a/src/include/zvec/c_api.h +++ b/src/include/zvec/c_api.h @@ -1032,6 +1032,37 @@ typedef struct zvec_vector_query_t zvec_vector_query_t; */ typedef struct zvec_group_by_vector_query_t zvec_group_by_vector_query_t; +/** + * @brief Document object (opaque pointer, forward declaration for reranker + * callback) + */ +typedef struct zvec_doc_t zvec_doc_t; + +/** + * @brief Reranker structure (opaque pointer) + * Aligned with zvec::Reranker + * Use zvec_reranker_create_rrf() or zvec_reranker_create_weighted() to create + * and zvec_reranker_destroy() to destroy + */ +typedef struct zvec_reranker_t zvec_reranker_t; +typedef struct zvec_collection_schema_t zvec_collection_schema_t; + +/** + * @brief Multi-query query structure (opaque pointer) + * Aligned with zvec::MultiQuery + * Use zvec_multi_query_create() to create and + * zvec_multi_query_destroy() to destroy + */ +typedef struct zvec_multi_query_t zvec_multi_query_t; + +/** + * @brief Sub-query structure for multi-query (opaque pointer) + * Aligned with zvec::SubQuery + * Use zvec_sub_query_create() to create and + * zvec_sub_query_destroy() to destroy + */ +typedef struct zvec_sub_query_t zvec_sub_query_t; + // ============================================================================= // Query Parameters Management Functions @@ -1704,6 +1735,268 @@ ZVEC_EXPORT zvec_error_code_t ZVEC_CALL zvec_group_by_vector_query_set_flat_params( zvec_group_by_vector_query_t *query, zvec_flat_query_params_t *flat_params); +// ----------------------------------------------------------------------------- +// zvec_reranker_t (Reranker) +// ----------------------------------------------------------------------------- + +/** + * @brief Create an RRF (Reciprocal Rank Fusion) reranker + * @param rank_constant RRF rank constant (default: 60) + * @return zvec_reranker_t* Pointer to the newly created reranker + */ +ZVEC_EXPORT zvec_reranker_t *ZVEC_CALL +zvec_reranker_create_rrf(int rank_constant); + +/** + * @brief Create a Weighted reranker + * @param fields Array of field names + * @param weights Array of weights corresponding to fields + * @param field_count Number of field/weight entries + * @return zvec_reranker_t* Pointer to the newly created reranker + */ +ZVEC_EXPORT zvec_reranker_t *ZVEC_CALL zvec_reranker_create_weighted( + const char **fields, const double *weights, size_t field_count); + +/** + * @brief Destroy reranker + * @param reranker Reranker pointer + */ +ZVEC_EXPORT void ZVEC_CALL zvec_reranker_destroy(zvec_reranker_t *reranker); + +/** + * @brief Get RRF rank constant (only valid for RRF reranker) + * @param reranker Reranker pointer + * @return int Rank constant, or -1 if not an RRF reranker + */ +ZVEC_EXPORT int ZVEC_CALL +zvec_reranker_get_rank_constant(const zvec_reranker_t *reranker); + +// ----------------------------------------------------------------------------- +// zvec_multi_query_t (Multi Query) +// ----------------------------------------------------------------------------- + +/** + * @brief Create multi-query query + * @return zvec_multi_query_t* Pointer to the newly created multi query + */ +ZVEC_EXPORT zvec_multi_query_t *ZVEC_CALL zvec_multi_query_create(void); + +/** + * @brief Destroy multi-query query + * @param query Multi query pointer + */ +ZVEC_EXPORT void ZVEC_CALL zvec_multi_query_destroy(zvec_multi_query_t *query); + +/** + * @brief Add a sub-query to the multi-query query + * @param query Multi query pointer + * @param sub_query Sub-query to add (copied, caller retains ownership) + * @return zvec_error_code_t Error code + */ +ZVEC_EXPORT zvec_error_code_t ZVEC_CALL zvec_multi_query_add_sub_query( + zvec_multi_query_t *query, const zvec_sub_query_t *sub_query); + +/** + * @brief Get number of sub-queries + * @param query Multi query pointer + * @return size_t Number of sub-queries + */ +ZVEC_EXPORT size_t ZVEC_CALL +zvec_multi_query_get_sub_query_count(const zvec_multi_query_t *query); + +/** + * @brief Set topk + * @param query Multi-vector query pointer + * @param topk Number of results + * @return zvec_error_code_t Error code + */ +ZVEC_EXPORT zvec_error_code_t ZVEC_CALL +zvec_multi_query_set_topk(zvec_multi_query_t *query, int topk); + +/** + * @brief Get topk + * @param query Multi-vector query pointer + * @return int Number of results + */ +ZVEC_EXPORT int ZVEC_CALL +zvec_multi_query_get_topk(const zvec_multi_query_t *query); + +/** + * @brief Set filter expression + * @param query Multi-vector query pointer + * @param filter Filter expression string + * @return zvec_error_code_t Error code + */ +ZVEC_EXPORT zvec_error_code_t ZVEC_CALL +zvec_multi_query_set_filter(zvec_multi_query_t *query, const char *filter); + +/** + * @brief Get filter expression + * @param query Multi-vector query pointer + * @return const char* Filter expression (owned by query, do not free) + */ +ZVEC_EXPORT const char *ZVEC_CALL +zvec_multi_query_get_filter(const zvec_multi_query_t *query); + +/** + * @brief Set whether to include vector data in results + * @param query Multi-vector query pointer + * @param include Whether to include vectors + * @return zvec_error_code_t Error code + */ +ZVEC_EXPORT zvec_error_code_t ZVEC_CALL +zvec_multi_query_set_include_vector(zvec_multi_query_t *query, bool include); + +/** + * @brief Get whether to include vector data in results + * @param query Multi-vector query pointer + * @return bool Whether to include vectors + */ +ZVEC_EXPORT bool ZVEC_CALL +zvec_multi_query_get_include_vector(const zvec_multi_query_t *query); + +/** + * @brief Set output fields + * @param query Multi-vector query pointer + * @param fields Array of field names + * @param count Number of fields + * @return zvec_error_code_t Error code + */ +ZVEC_EXPORT zvec_error_code_t ZVEC_CALL zvec_multi_query_set_output_fields( + zvec_multi_query_t *query, const char **fields, size_t count); + +/** + * @brief Get output fields + * @param query Multi-vector query pointer + * @param[out] fields Output array of field names (allocated by library) + * @param[out] count Number of fields + * @return zvec_error_code_t Error code + * + * @note The returned array is allocated by the library and should be freed + * using zvec_free() when no longer needed. The individual string pointers + * are owned by the query and must NOT be freed. + */ +ZVEC_EXPORT zvec_error_code_t ZVEC_CALL zvec_multi_query_get_output_fields( + zvec_multi_query_t *query, const char ***fields, size_t *count); + +/** + * @brief Set reranker (copies shared pointer, caller must still destroy + * reranker) + * @param query Multi-vector query pointer + * @param reranker Reranker pointer (remains valid, caller must call + * zvec_reranker_destroy after use) + * @return zvec_error_code_t Error code + */ +ZVEC_EXPORT zvec_error_code_t ZVEC_CALL zvec_multi_query_set_reranker( + zvec_multi_query_t *query, zvec_reranker_t *reranker); + +// ----------------------------------------------------------------------------- +// zvec_sub_query_t (Sub-Query for Multi Query) +// ----------------------------------------------------------------------------- + +/** + * @brief Create sub-query + * @return zvec_sub_query_t* Pointer to the newly created + * sub-query + */ +ZVEC_EXPORT zvec_sub_query_t *ZVEC_CALL zvec_sub_query_create(void); + +/** + * @brief Destroy sub-query + * @param query Sub-query pointer + */ +ZVEC_EXPORT void ZVEC_CALL zvec_sub_query_destroy(zvec_sub_query_t *query); + +/** + * @brief Set number of candidates to retrieve per field + * @param query Sub-query pointer + * @param num_candidates Number of candidates + * @return zvec_error_code_t Error code + */ +ZVEC_EXPORT zvec_error_code_t ZVEC_CALL +zvec_sub_query_set_num_candidates(zvec_sub_query_t *query, int num_candidates); + +/** + * @brief Get number of candidates + * @param query Sub-query pointer + * @return int Number of candidates + */ +ZVEC_EXPORT int ZVEC_CALL +zvec_sub_query_get_num_candidates(const zvec_sub_query_t *query); + +/** + * @brief Set field name + * @param query Sub-query pointer + * @param field_name Field name + * @return zvec_error_code_t Error code + */ +ZVEC_EXPORT zvec_error_code_t ZVEC_CALL +zvec_sub_query_set_field_name(zvec_sub_query_t *query, const char *field_name); + +/** + * @brief Get field name + * @param query Sub-query pointer + * @return const char* Field name (owned by query, do not free) + */ +ZVEC_EXPORT const char *ZVEC_CALL +zvec_sub_query_get_field_name(const zvec_sub_query_t *query); + +/** + * @brief Set query vector data + * @param query Sub-query pointer + * @param data Vector data pointer + * @param size Data size in bytes + * @return zvec_error_code_t Error code + */ +ZVEC_EXPORT zvec_error_code_t ZVEC_CALL zvec_sub_query_set_query_vector( + zvec_sub_query_t *query, const void *data, size_t size); + +/** + * @brief Set sparse vector indices + * @param query Sub-query pointer + * @param indices Array of uint32_t indices + * @param count Number of indices + * @return zvec_error_code_t Error code + */ +ZVEC_EXPORT zvec_error_code_t ZVEC_CALL zvec_sub_query_set_sparse_indices( + zvec_sub_query_t *query, const uint32_t *indices, size_t count); + +/** + * @brief Set sparse vector values + * @param query Sub-query pointer + * @param values Array of float values + * @param count Number of values + * @return zvec_error_code_t Error code + */ +ZVEC_EXPORT zvec_error_code_t ZVEC_CALL zvec_sub_query_set_sparse_values( + zvec_sub_query_t *query, const float *values, size_t count); + +/** + * @brief Set HNSW query parameters (takes ownership) + * @param query Sub-query pointer + * @param hnsw_params HNSW query parameters pointer + * @return zvec_error_code_t Error code + */ +ZVEC_EXPORT zvec_error_code_t ZVEC_CALL zvec_sub_query_set_hnsw_params( + zvec_sub_query_t *query, zvec_hnsw_query_params_t *hnsw_params); + +/** + * @brief Set IVF query parameters (takes ownership) + * @param query Sub-query pointer + * @param ivf_params IVF query parameters pointer + * @return zvec_error_code_t Error code + */ +ZVEC_EXPORT zvec_error_code_t ZVEC_CALL zvec_sub_query_set_ivf_params( + zvec_sub_query_t *query, zvec_ivf_query_params_t *ivf_params); + +/** + * @brief Set Flat query parameters (takes ownership) + * @param query Sub-query pointer + * @param flat_params Flat query parameters pointer + * @return zvec_error_code_t Error code + */ +ZVEC_EXPORT zvec_error_code_t ZVEC_CALL zvec_sub_query_set_flat_params( + zvec_sub_query_t *query, zvec_flat_query_params_t *flat_params); // ============================================================================= // Collection Options and Statistics (Opaque Pointer Pattern) // ============================================================================= @@ -2024,16 +2317,6 @@ ZVEC_EXPORT zvec_error_code_t ZVEC_CALL zvec_field_schema_validate( // Collection Schema Structures (Opaque Pointer Pattern) // ============================================================================= -/** - * @brief Collection schema handle (opaque pointer) - * - * Internally maps to zvec::CollectionSchema* (raw pointer). - * Created by zvec_collection_schema_create() and destroyed by - * zvec_collection_schema_destroy(). Caller owns the pointer and must explicitly - * destroy it. - */ -typedef struct zvec_collection_schema_t zvec_collection_schema_t; - /** * @brief Create collection schema * @param name Collection name @@ -2476,13 +2759,6 @@ ZVEC_EXPORT zvec_error_code_t ZVEC_CALL zvec_collection_alter_column( zvec_collection_t *collection, const char *column_name, const char *new_name, const zvec_field_schema_t *new_schema); -/** - * @brief Document structure (opaque pointer mode) - * Internal implementation details are not visible to the outside, and - * operations are performed through API functions - */ -typedef struct zvec_doc_t zvec_doc_t; - /** * @brief Per-document status returned by detailed DML APIs. * @note Uses ordered style: result index corresponds to input document index. @@ -2645,6 +2921,19 @@ ZVEC_EXPORT zvec_error_code_t ZVEC_CALL zvec_collection_query( const zvec_collection_t *collection, const zvec_vector_query_t *query, zvec_doc_t ***results, size_t *result_count); +/** + * @brief Multi-query with multiple sub-queries and re-ranking + * @param collection Collection handle + * @param query Multi-query query parameters pointer + * @param[out] results Returned document array (needs to be freed by calling + * zvec_docs_free) + * @param[out] result_count Number of returned results + * @return zvec_error_code_t Error code + */ +ZVEC_EXPORT zvec_error_code_t ZVEC_CALL zvec_collection_multi_query( + const zvec_collection_t *collection, const zvec_multi_query_t *query, + zvec_doc_t ***results, size_t *result_count); + /** * @brief Fetch documents by primary keys * @param collection Collection handle diff --git a/src/include/zvec/db/collection.h b/src/include/zvec/db/collection.h index 6dd4596b2..3df2840b7 100644 --- a/src/include/zvec/db/collection.h +++ b/src/include/zvec/db/collection.h @@ -17,8 +17,8 @@ #include #include #include -#include #include +#include #include #include @@ -99,6 +99,8 @@ class Collection { virtual Result Query(const VectorQuery &query) const = 0; + virtual Result Query(const MultiQuery &query) const = 0; + virtual Result GroupByQuery( const GroupByVectorQuery &query) const = 0; diff --git a/src/include/zvec/db/doc.h b/src/include/zvec/db/doc.h index f702a43c3..3dbe9a7c9 100644 --- a/src/include/zvec/db/doc.h +++ b/src/include/zvec/db/doc.h @@ -20,7 +20,6 @@ #include #include #include -#include #include #include #include @@ -364,44 +363,4 @@ using DocPtrMap = std::unordered_map; using WriteResults = std::vector; -struct VectorQuery { - int topk_; - std::string field_name_; - std::string query_vector_; // fp16, void * - std::string query_sparse_indices_; - std::string query_sparse_values_; - std::string filter_; - bool include_vector_{false}; - bool include_doc_id_{false}; - // select * by default, select no field if output_fields_ is empty, select - // specific fields if output_fields_ is not empty - std::optional> output_fields_; - QueryParams::Ptr query_params_; - - Status validate_and_sanitize(const FieldSchema *schema); -}; - -struct GroupByVectorQuery { - std::string field_name_; - std::string query_vector_; - std::string query_sparse_indices_; - std::string query_sparse_values_; - std::string filter_; - bool include_vector_; - // select * by default, select no field if output_fields_ is empty, select - // specific fields if output_fields_ is not empty - std::optional> output_fields_; - std::string group_by_field_name_; - uint32_t group_count_ = 2; - uint32_t group_topk_ = 3; - QueryParams::Ptr query_params_; -}; - -struct GroupResult { - std::string group_by_value_; - std::vector docs_; -}; - -using GroupResults = std::vector; - } // namespace zvec diff --git a/src/include/zvec/db/query.h b/src/include/zvec/db/query.h new file mode 100644 index 000000000..98e2c7395 --- /dev/null +++ b/src/include/zvec/db/query.h @@ -0,0 +1,102 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace zvec { + +struct VectorQuery { + int topk_; + std::string field_name_; + std::string query_vector_; // fp16, void * + std::string query_sparse_indices_; + std::string query_sparse_values_; + std::string filter_; + bool include_vector_{false}; + bool include_doc_id_{false}; + // select * by default, select no field if output_fields_ is empty, select + // specific fields if output_fields_ is not empty + std::optional> output_fields_; + QueryParams::Ptr query_params_; + + Status validate_and_sanitize(const FieldSchema *schema); +}; + +struct GroupByVectorQuery { + std::string field_name_; + std::string query_vector_; + std::string query_sparse_indices_; + std::string query_sparse_values_; + std::string filter_; + bool include_vector_; + // select * by default, select no field if output_fields_ is empty, select + // specific fields if output_fields_ is not empty + std::optional> output_fields_; + std::string group_by_field_name_; + uint32_t group_count_ = 2; + uint32_t group_topk_ = 3; + QueryParams::Ptr query_params_; +}; + +//! Multi query structure for combining multiple sub-queries +//! (vector, full-text, etc.) with optional re-ranking of results. + +struct VectorClause { + std::string query_vector_; + std::string sparse_indices_; + std::string sparse_values_; +}; + +struct FtsClause { + std::string query_string_; + std::string match_string_; +}; + +struct QueryTarget { + std::string field_name_; + std::variant clause_; + QueryParams::Ptr query_params_; +}; + +struct SubQuery { + QueryTarget target_; + int num_candidates_{10}; +}; + +struct MultiQuery { + std::vector queries; + int topk{10}; + std::string filter; + bool include_vector{false}; + bool include_doc_id_{false}; + std::optional> output_fields; + std::shared_ptr reranker{nullptr}; +}; + +struct GroupResult { + std::string group_by_value_; + std::vector docs_; +}; + +using GroupResults = std::vector; + +} // namespace zvec diff --git a/src/include/zvec/db/reranker.h b/src/include/zvec/db/reranker.h new file mode 100644 index 000000000..0b8b56ef8 --- /dev/null +++ b/src/include/zvec/db/reranker.h @@ -0,0 +1,135 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include "zvec/db/status.h" + +namespace zvec { + +//! Reranker abstract base class for re-ranking search results +class Reranker { + public: + using Ptr = std::shared_ptr; + + Reranker() = default; + virtual ~Reranker() = default; + + virtual void bind_schema(CollectionSchema::Ptr) {} + + //! Re-rank documents from one or more vector queries. + //! \param query_results Mapping from vector field name to list of retrieved + //! documents (sorted by relevance). + //! \param topn Maximum number of documents to return. + //! \return Re-ranked list of documents (length <= topn), with updated scores. + virtual Result rerank( + const std::map &query_results, + int topn = 10) const = 0; +}; + +//! Intermediate base for rerankers that compute per-document scores. +//! +//! Implements the common rerank() logic: iterate docs, call rescore() for each, +//! accumulate scores by doc_id, and return topn results in descending order. +//! Subclasses only need to implement rescore(). +class ScoreBasedReranker : public Reranker { + public: + //! Compute the contribution score for a single document. + //! \param score The document's raw relevance score from the vector field. + //! \param rank The document's position (0-based) in the per-field result + //! list. \param field_name The name of the vector field this result came + //! from. \return The score contribution to be accumulated for this document. + virtual Result rescore(double score, int rank, + const std::string &field_name) const = 0; + + Result rerank( + const std::map &query_results, + int topn = 10) const override; +}; + +//! Re-ranker using Reciprocal Rank Fusion (RRF) for multi-vector search. +//! +//! RRF combines results from multiple vector queries without requiring +//! relevance scores. The RRF score for a document at rank r is: +//! score = 1 / (k + r + 1) +//! where k is the rank constant. +class RrfReranker : public ScoreBasedReranker { + public: + explicit RrfReranker(int rank_constant = 60) + : rank_constant_(rank_constant) {} + + int rank_constant() const { + return rank_constant_; + } + + Result rescore(double score, int rank, + const std::string &field_name) const override; + + private: + int rank_constant_; +}; + +//! Re-ranker that combines scores from multiple vector fields using weights. +//! +//! Each vector field's relevance score is normalized based on its own metric +//! type, then scaled by a user-provided weight. Final scores are summed across +//! fields. Supported metrics: L2, IP, COSINE. +class WeightedReranker : public ScoreBasedReranker { + public: + explicit WeightedReranker(const std::map &weights = {}); + + void bind_schema(CollectionSchema::Ptr schema) override; + + const std::map &weights() const { + return weights_; + } + + Result rescore(double score, int rank, + const std::string &field_name) const override; + + private: + static Result normalize_score(double score, const FieldSchema &field); + + CollectionSchema::Ptr schema_; + std::map weights_; +}; + +//! Callback-based re-ranker for cross-language bridging. +//! +//! Wraps a user-provided callback (e.g., a Python callable) as a Reranker. +//! When the callback is a Python function, GIL must be managed by the caller. +class CallbackReranker : public Reranker { + public: + using Callback = + std::function &, int)>; + + explicit CallbackReranker(Callback fn) : callback_(std::move(fn)) {} + + Result rerank( + const std::map &query_results, + int topn = 10) const override { + return callback_(query_results, topn); + } + + private: + Callback callback_; +}; + +} // namespace zvec diff --git a/tests/c/c_api_test.c b/tests/c/c_api_test.c index 846cc548c..c9ba891f7 100644 --- a/tests/c/c_api_test.c +++ b/tests/c/c_api_test.c @@ -4127,6 +4127,231 @@ void test_actual_vector_queries(void) { TEST_END(); } +void test_reranker_functions(void) { + TEST_START(); + + // Test 1: Create RRF reranker + zvec_reranker_t *rrf = zvec_reranker_create_rrf(60); + TEST_ASSERT(rrf != NULL); + if (rrf) { + TEST_ASSERT(zvec_reranker_get_rank_constant(rrf) == 60); + zvec_reranker_destroy(rrf); + } + + // Test 2: Create RRF reranker with different rank constant + zvec_reranker_t *rrf2 = zvec_reranker_create_rrf(100); + TEST_ASSERT(rrf2 != NULL); + if (rrf2) { + TEST_ASSERT(zvec_reranker_get_rank_constant(rrf2) == 100); + zvec_reranker_destroy(rrf2); + } + + // Test 3: Create Weighted reranker + const char *fields[] = {"embedding1", "embedding2"}; + double weights[] = {0.7, 0.3}; + zvec_reranker_t *weighted = zvec_reranker_create_weighted(fields, weights, 2); + TEST_ASSERT(weighted != NULL); + if (weighted) { + TEST_ASSERT(zvec_reranker_get_rank_constant(weighted) == -1); + zvec_reranker_destroy(weighted); + } + + // Test 4: Create Weighted reranker with no fields + zvec_reranker_t *weighted2 = zvec_reranker_create_weighted(NULL, NULL, 0); + TEST_ASSERT(weighted2 != NULL); + if (weighted2) { + zvec_reranker_destroy(weighted2); + } + + // Test 5: NULL reranker operations + TEST_ASSERT(zvec_reranker_get_rank_constant(NULL) == -1); + zvec_reranker_destroy(NULL); // Should not crash + + TEST_END(); +} + +// ==================== Multi-query reranker test helpers ==================== + +typedef struct { + zvec_collection_t *collection; + zvec_collection_schema_t *schema; + zvec_doc_t *docs[4]; + float e1_v1[4], e2_v1[4]; + char temp_dir[64]; +} multi_query_fixture_t; + +static int setup_multi_query_fixture(multi_query_fixture_t *f, + const char *dir_name, + const char *schema_name) { + snprintf(f->temp_dir, sizeof(f->temp_dir), "./%s", dir_name); + f->collection = NULL; + f->schema = zvec_collection_schema_create(schema_name); + if (!f->schema) return 0; + + zvec_field_schema_t *id_field = + zvec_field_schema_create("id", ZVEC_DATA_TYPE_INT64, false, 0); + zvec_collection_schema_add_field(f->schema, id_field); + + for (int i = 0; i < 2; i++) { + const char *name = i == 0 ? "embedding1" : "embedding2"; + zvec_index_params_t *hnsw = zvec_index_params_create(ZVEC_INDEX_TYPE_HNSW); + zvec_index_params_set_metric_type(hnsw, ZVEC_METRIC_TYPE_L2); + zvec_index_params_set_hnsw_params(hnsw, 16, 100); + zvec_field_schema_t *vec = + zvec_field_schema_create(name, ZVEC_DATA_TYPE_VECTOR_FP32, false, 4); + zvec_field_schema_set_index_params(vec, hnsw); + zvec_collection_schema_add_field(f->schema, vec); + zvec_index_params_destroy(hnsw); + } + + zvec_error_code_t err = zvec_collection_create_and_open( + f->temp_dir, f->schema, NULL, &f->collection); + if (err != ZVEC_OK || !f->collection) return 0; + + float e1[4][4] = {{1, 0, 0, 0}, {0, 1, 0, 0}, {0, 0, 1, 0}, {.7f, .7f, 0, 0}}; + float e2[4][4] = {{0, 1, 0, 0}, {1, 0, 0, 0}, {0, 0, 0, 1}, {.5f, .5f, 0, 0}}; + memcpy(f->e1_v1, e1[0], sizeof(f->e1_v1)); + memcpy(f->e2_v1, e2[0], sizeof(f->e2_v1)); + + for (int i = 0; i < 4; i++) { + f->docs[i] = zvec_doc_create(); + zvec_doc_set_pk(f->docs[i], zvec_test_make_pk(i + 1)); + zvec_doc_add_field_by_value(f->docs[i], "id", ZVEC_DATA_TYPE_INT64, + &(int64_t){i + 1}, sizeof(int64_t)); + zvec_doc_add_field_by_value(f->docs[i], "embedding1", + ZVEC_DATA_TYPE_VECTOR_FP32, e1[i], + sizeof(e1[i])); + zvec_doc_add_field_by_value(f->docs[i], "embedding2", + ZVEC_DATA_TYPE_VECTOR_FP32, e2[i], + sizeof(e2[i])); + } + + size_t success_count, error_count; + err = zvec_collection_insert(f->collection, (const zvec_doc_t **)f->docs, 4, + &success_count, &error_count); + if (err != ZVEC_OK || success_count != 4) return 0; + + zvec_collection_flush(f->collection); + return 1; +} + +static void teardown_multi_query_fixture(multi_query_fixture_t *f) { + for (int i = 0; i < 4; i++) zvec_doc_destroy(f->docs[i]); + zvec_collection_destroy(f->collection); + zvec_collection_schema_destroy(f->schema); + cleanup_temp_directory(f->temp_dir); +} + +static int execute_multi_query_with_reranker(const multi_query_fixture_t *f, + zvec_reranker_t *reranker, + int topk, int num_candidates) { + zvec_multi_query_t *mvq = zvec_multi_query_create(); + if (!mvq) return -1; + zvec_multi_query_set_topk(mvq, topk); + zvec_multi_query_set_include_vector(mvq, false); + + zvec_sub_query_t *vq1 = zvec_sub_query_create(); + zvec_sub_query_set_field_name(vq1, "embedding1"); + zvec_sub_query_set_query_vector(vq1, f->e1_v1, sizeof(f->e1_v1)); + zvec_sub_query_set_num_candidates(vq1, num_candidates); + zvec_multi_query_add_sub_query(mvq, vq1); + + zvec_sub_query_t *vq2 = zvec_sub_query_create(); + zvec_sub_query_set_field_name(vq2, "embedding2"); + zvec_sub_query_set_query_vector(vq2, f->e2_v1, sizeof(f->e2_v1)); + zvec_sub_query_set_num_candidates(vq2, num_candidates); + zvec_multi_query_add_sub_query(mvq, vq2); + + zvec_multi_query_set_reranker(mvq, reranker); + + zvec_doc_t **results = NULL; + size_t result_count = 0; + zvec_error_code_t err = + zvec_collection_multi_query(f->collection, mvq, &results, &result_count); + + int ret = -1; + if (err == ZVEC_OK && results != NULL) { + ret = (int)result_count; + zvec_docs_free(results, result_count); + } + + zvec_sub_query_destroy(vq1); + zvec_sub_query_destroy(vq2); + zvec_multi_query_destroy(mvq); + return ret; +} + +// ==================== Multi-query reranker tests ==================== + +void test_multi_vector_query_with_rrf_reranker(void) { + TEST_START(); + + multi_query_fixture_t f; + TEST_ASSERT(setup_multi_query_fixture(&f, "zvec_test_mq_rrf", "mq_rrf")); + + zvec_reranker_t *rrf = zvec_reranker_create_rrf(60); + TEST_ASSERT(rrf != NULL); + + int count = execute_multi_query_with_reranker(&f, rrf, 3, 3); + TEST_ASSERT(count > 0); + TEST_ASSERT(count <= 3); + + zvec_reranker_destroy(rrf); + + // MultiQuery property setters/getters + zvec_multi_query_t *mvq2 = zvec_multi_query_create(); + TEST_ASSERT(mvq2 != NULL); + zvec_multi_query_set_topk(mvq2, 5); + TEST_ASSERT(zvec_multi_query_get_topk(mvq2) == 5); + + zvec_multi_query_set_filter(mvq2, "id > 1"); + TEST_ASSERT(strcmp(zvec_multi_query_get_filter(mvq2), "id > 1") == 0); + + zvec_multi_query_set_include_vector(mvq2, true); + TEST_ASSERT(zvec_multi_query_get_include_vector(mvq2) == true); + + const char *out_fields[] = {"id"}; + zvec_multi_query_set_output_fields(mvq2, out_fields, 1); + const char **got_fields = NULL; + size_t field_count = 0; + zvec_error_code_t err = + zvec_multi_query_get_output_fields(mvq2, &got_fields, &field_count); + TEST_ASSERT(err == ZVEC_OK); + TEST_ASSERT(field_count == 1); + if (field_count > 0) { + TEST_ASSERT(strcmp(got_fields[0], "id") == 0); + zvec_free((char *)got_fields); + } + + zvec_multi_query_destroy(mvq2); + + teardown_multi_query_fixture(&f); + + TEST_END(); +} + +void test_multi_vector_query_with_weighted_reranker(void) { + TEST_START(); + + multi_query_fixture_t f; + TEST_ASSERT( + setup_multi_query_fixture(&f, "zvec_test_mq_weighted", "mq_weighted")); + + const char *fields[] = {"embedding1", "embedding2"}; + double weights[] = {0.7, 0.3}; + zvec_reranker_t *weighted = zvec_reranker_create_weighted(fields, weights, 2); + TEST_ASSERT(weighted != NULL); + + int count = execute_multi_query_with_reranker(&f, weighted, 3, 3); + TEST_ASSERT(count > 0); + TEST_ASSERT(count <= 3); + + zvec_reranker_destroy(weighted); + teardown_multi_query_fixture(&f); + + TEST_END(); +} + void test_index_creation_and_management(void) { TEST_START(); @@ -5448,7 +5673,9 @@ int main(void) { // Query tests test_query_params_functions(); test_actual_vector_queries(); - + test_reranker_functions(); + test_multi_vector_query_with_rrf_reranker(); + test_multi_vector_query_with_weighted_reranker(); // Performance tests // test_performance_benchmarks(); diff --git a/tests/db/collection_test.cc b/tests/db/collection_test.cc index 9e2adfbbb..931740155 100644 --- a/tests/db/collection_test.cc +++ b/tests/db/collection_test.cc @@ -34,6 +34,7 @@ #include "zvec/db/doc.h" #include "zvec/db/index_params.h" #include "zvec/db/options.h" +#include "zvec/db/reranker.h" #include "zvec/db/schema.h" #include "zvec/db/status.h" #include "zvec/db/type.h" @@ -93,7 +94,8 @@ TEST_F(CollectionTest, Feature_CreateAndOpen_General) { ASSERT_FALSE(col->Delete({}).has_value()); ASSERT_FALSE(col->DeleteByFilter("").ok()); ASSERT_FALSE(col->Fetch({}).has_value()); - ASSERT_FALSE(col->Query({}).has_value()); + ASSERT_FALSE(col->Query(VectorQuery{}).has_value()); + ASSERT_FALSE(col->Query(MultiQuery{}).has_value()); ASSERT_FALSE(col->GroupByQuery({}).has_value()); ASSERT_FALSE(col->CreateIndex("", nullptr).ok()); ASSERT_FALSE(col->DropIndex("").ok()); @@ -3677,6 +3679,438 @@ TEST_F(CollectionTest, Feature_Query_WithoutVector_WithScalarIndex) { "array_int32 contain_any (1)", 1); } +// ============================================================================= +// MultiQuery Tests +// ============================================================================= + +TEST_F(CollectionTest, Feature_MultiQuery_Validate) { + FileHelper::RemoveDirectory(col_path); + + int doc_count = 100; + auto schema = TestHelper::CreateNormalSchema(); + auto options = CollectionOptions{false, true, 100 * 1024 * 1024}; + auto collection = TestHelper::CreateCollectionWithDoc(col_path, *schema, + options, 0, doc_count); + ASSERT_NE(collection, nullptr); + + // Test 1: Empty queries should fail + { + MultiQuery mvq; + mvq.topk = 10; + mvq.reranker = std::make_shared(60); + auto result = collection->Query(mvq); + ASSERT_FALSE(result.has_value()); + EXPECT_EQ(result.error().code(), StatusCode::INVALID_ARGUMENT); + } + + // Test 2: No reranker with multiple queries should fail + { + MultiQuery mvq; + mvq.topk = 10; + auto query_doc = TestHelper::CreateDoc(1, *schema); + + SubQuery vq1; + vq1.num_candidates_ = 10; + vq1.target_.field_name_ = "dense_fp32"; + auto vector = query_doc.get>("dense_fp32"); + ASSERT_TRUE(vector.has_value()); + std::get(vq1.target_.clause_) + .query_vector_.assign((char *)vector.value().data(), + vector.value().size() * sizeof(float)); + mvq.queries.push_back(vq1); + + SubQuery vq2; + vq2.num_candidates_ = 10; + vq2.target_.field_name_ = "dense_fp16"; + auto vector2 = query_doc.get>("dense_fp32"); + ASSERT_TRUE(vector2.has_value()); + std::get(vq2.target_.clause_) + .query_vector_.assign((char *)vector2.value().data(), + vector2.value().size() * sizeof(float)); + mvq.queries.push_back(vq2); + + auto result = collection->Query(mvq); + ASSERT_FALSE(result.has_value()); + EXPECT_EQ(result.error().code(), StatusCode::INVALID_ARGUMENT); + } + + // Test 3: Invalid field name should fail + { + MultiQuery mvq; + mvq.topk = 10; + mvq.reranker = std::make_shared(60); + + SubQuery vq1; + vq1.num_candidates_ = 10; + vq1.target_.field_name_ = "nonexistent_field"; + std::get(vq1.target_.clause_) + .query_vector_.assign(128 * sizeof(float), '\0'); + mvq.queries.push_back(vq1); + + SubQuery vq2; + vq2.num_candidates_ = 10; + vq2.target_.field_name_ = "dense_fp32"; + std::get(vq2.target_.clause_) + .query_vector_.assign(128 * sizeof(float), '\0'); + mvq.queries.push_back(vq2); + + auto result = collection->Query(mvq); + ASSERT_FALSE(result.has_value()); + EXPECT_EQ(result.error().code(), StatusCode::INVALID_ARGUMENT); + } + + // Test 4: Duplicate field names should fail + { + MultiQuery mvq; + mvq.topk = 10; + mvq.reranker = std::make_shared(60); + + SubQuery vq1; + vq1.num_candidates_ = 10; + vq1.target_.field_name_ = "dense_fp32"; + std::get(vq1.target_.clause_) + .query_vector_.assign(128 * sizeof(float), '\0'); + mvq.queries.push_back(vq1); + + SubQuery vq2; + vq2.num_candidates_ = 10; + vq2.target_.field_name_ = "dense_fp32"; + std::get(vq2.target_.clause_) + .query_vector_.assign(128 * sizeof(float), '\0'); + mvq.queries.push_back(vq2); + + auto result = collection->Query(mvq); + ASSERT_FALSE(result.has_value()); + EXPECT_EQ(result.error().code(), StatusCode::INVALID_ARGUMENT); + } +} + +TEST_F(CollectionTest, Feature_MultiQuery_SingleFieldWithReranker) { + FileHelper::RemoveDirectory(col_path); + + int doc_count = 100; + auto schema = TestHelper::CreateNormalSchema(); + auto options = CollectionOptions{false, true, 100 * 1024 * 1024}; + auto collection = TestHelper::CreateCollectionWithDoc(col_path, *schema, + options, 0, doc_count); + ASSERT_NE(collection, nullptr); + + // Single query with reranker should fail (requires at least 2 sub-queries) + auto query_doc = TestHelper::CreateDoc(1, *schema); + + MultiQuery mvq; + mvq.topk = 10; + mvq.reranker = std::make_shared(60); + + SubQuery vq; + vq.num_candidates_ = 10; + vq.target_.field_name_ = "dense_fp32"; + auto vector = query_doc.get>("dense_fp32"); + ASSERT_TRUE(vector.has_value()); + std::get(vq.target_.clause_) + .query_vector_.assign((char *)vector.value().data(), + vector.value().size() * sizeof(float)); + mvq.queries.push_back(vq); + + auto result = collection->Query(mvq); + ASSERT_FALSE(result.has_value()); + EXPECT_EQ(result.error().code(), StatusCode::INVALID_ARGUMENT); +} + +TEST_F(CollectionTest, Feature_MultiQuery_MultiFieldRRF) { + FileHelper::RemoveDirectory(col_path); + + int doc_count = 100; + auto schema = TestHelper::CreateNormalSchema(); + auto options = CollectionOptions{false, true, 100 * 1024 * 1024}; + auto collection = TestHelper::CreateCollectionWithDoc(col_path, *schema, + options, 0, doc_count); + ASSERT_NE(collection, nullptr); + + auto query_doc = TestHelper::CreateDoc(1, *schema); + + MultiQuery mvq; + mvq.topk = 10; + mvq.reranker = std::make_shared(60); + + // Query dense_fp32 and dense_fp16 fields with different vectors + auto vector1 = query_doc.get>("dense_fp32"); + ASSERT_TRUE(vector1.has_value()); + + { + SubQuery vq; + vq.num_candidates_ = 10; + vq.target_.field_name_ = "dense_fp32"; + std::get(vq.target_.clause_) + .query_vector_.assign((char *)vector1.value().data(), + vector1.value().size() * sizeof(float)); + mvq.queries.push_back(vq); + } + + // Query sparse_fp32 field + auto sparse = + query_doc.get, std::vector>>( + "sparse_fp32"); + ASSERT_TRUE(sparse.has_value()); + + { + SubQuery vq; + vq.num_candidates_ = 10; + vq.target_.field_name_ = "sparse_fp32"; + std::get(vq.target_.clause_) + .sparse_indices_.assign((char *)sparse.value().first.data(), + sparse.value().first.size() * sizeof(uint32_t)); + std::get(vq.target_.clause_) + .sparse_values_.assign((char *)sparse.value().second.data(), + sparse.value().second.size() * sizeof(float)); + mvq.queries.push_back(vq); + } + + auto result = collection->Query(mvq); + ASSERT_TRUE(result.has_value()) << result.error().message(); + EXPECT_GT(result.value().size(), 0u); + EXPECT_LE(result.value().size(), 10u); + + // All results should have valid scores (RRF fused) + for (const auto &doc : result.value()) { + EXPECT_NE(doc->score(), 0.0f); + } +} + +TEST_F(CollectionTest, Feature_MultiQuery_MultiFieldWeighted) { + FileHelper::RemoveDirectory(col_path); + + int doc_count = 100; + auto schema = TestHelper::CreateNormalSchema(); + auto options = CollectionOptions{false, true, 100 * 1024 * 1024}; + auto collection = TestHelper::CreateCollectionWithDoc(col_path, *schema, + options, 0, doc_count); + ASSERT_NE(collection, nullptr); + + auto query_doc = TestHelper::CreateDoc(1, *schema); + + MultiQuery mvq; + mvq.topk = 10; + std::map weights = {{"dense_fp32", 0.7}, + {"sparse_fp32", 0.3}}; + mvq.reranker = std::make_shared(weights); + + // Query dense_fp32 field + { + SubQuery vq; + vq.num_candidates_ = 10; + vq.target_.field_name_ = "dense_fp32"; + auto vector = query_doc.get>("dense_fp32"); + ASSERT_TRUE(vector.has_value()); + std::get(vq.target_.clause_) + .query_vector_.assign((char *)vector.value().data(), + vector.value().size() * sizeof(float)); + mvq.queries.push_back(vq); + } + + // Query sparse_fp32 field + { + SubQuery vq; + vq.num_candidates_ = 10; + vq.target_.field_name_ = "sparse_fp32"; + auto sparse = + query_doc.get, std::vector>>( + "sparse_fp32"); + ASSERT_TRUE(sparse.has_value()); + std::get(vq.target_.clause_) + .sparse_indices_.assign((char *)sparse.value().first.data(), + sparse.value().first.size() * sizeof(uint32_t)); + std::get(vq.target_.clause_) + .sparse_values_.assign((char *)sparse.value().second.data(), + sparse.value().second.size() * sizeof(float)); + mvq.queries.push_back(vq); + } + + auto result = collection->Query(mvq); + ASSERT_TRUE(result.has_value()) << result.error().message(); + EXPECT_GT(result.value().size(), 0u); + EXPECT_LE(result.value().size(), 10u); +} + +TEST_F(CollectionTest, Feature_MultiQuery_WithFilter) { + FileHelper::RemoveDirectory(col_path); + + int doc_count = 100; + auto schema = TestHelper::CreateNormalSchema(); + auto options = CollectionOptions{false, true, 100 * 1024 * 1024}; + auto collection = TestHelper::CreateCollectionWithDoc(col_path, *schema, + options, 0, doc_count); + ASSERT_NE(collection, nullptr); + + auto query_doc = TestHelper::CreateDoc(1, *schema); + + MultiQuery mvq; + mvq.topk = 10; + mvq.filter = "int32 > 50"; + mvq.reranker = std::make_shared(60); + + SubQuery vq1; + vq1.num_candidates_ = 10; + vq1.target_.field_name_ = "dense_fp32"; + auto vector = query_doc.get>("dense_fp32"); + ASSERT_TRUE(vector.has_value()); + std::get(vq1.target_.clause_) + .query_vector_.assign((char *)vector.value().data(), + vector.value().size() * sizeof(float)); + mvq.queries.push_back(vq1); + + auto sparse = + query_doc.get, std::vector>>( + "sparse_fp32"); + ASSERT_TRUE(sparse.has_value()); + SubQuery vq2; + vq2.num_candidates_ = 10; + vq2.target_.field_name_ = "sparse_fp32"; + std::get(vq2.target_.clause_) + .sparse_indices_.assign((char *)sparse.value().first.data(), + sparse.value().first.size() * sizeof(uint32_t)); + std::get(vq2.target_.clause_) + .sparse_values_.assign((char *)sparse.value().second.data(), + sparse.value().second.size() * sizeof(float)); + mvq.queries.push_back(vq2); + + auto result = collection->Query(mvq); + ASSERT_TRUE(result.has_value()) << result.error().message(); + EXPECT_GT(result.value().size(), 0u); + EXPECT_LE(result.value().size(), 10u); +} + +TEST_F(CollectionTest, Feature_MultiQuery_WithOutputFields) { + FileHelper::RemoveDirectory(col_path); + + int doc_count = 100; + auto schema = TestHelper::CreateNormalSchema(); + auto options = CollectionOptions{false, true, 100 * 1024 * 1024}; + auto collection = TestHelper::CreateCollectionWithDoc(col_path, *schema, + options, 0, doc_count); + ASSERT_NE(collection, nullptr); + + auto query_doc = TestHelper::CreateDoc(1, *schema); + + MultiQuery mvq; + mvq.topk = 5; + mvq.include_vector = false; + mvq.output_fields = std::make_optional>( + std::vector{"int32", "string"}); + mvq.reranker = std::make_shared(60); + + SubQuery vq1; + vq1.num_candidates_ = 10; + vq1.target_.field_name_ = "dense_fp32"; + auto vector = query_doc.get>("dense_fp32"); + ASSERT_TRUE(vector.has_value()); + std::get(vq1.target_.clause_) + .query_vector_.assign((char *)vector.value().data(), + vector.value().size() * sizeof(float)); + mvq.queries.push_back(vq1); + + auto sparse = + query_doc.get, std::vector>>( + "sparse_fp32"); + ASSERT_TRUE(sparse.has_value()); + SubQuery vq2; + vq2.num_candidates_ = 10; + vq2.target_.field_name_ = "sparse_fp32"; + std::get(vq2.target_.clause_) + .sparse_indices_.assign((char *)sparse.value().first.data(), + sparse.value().first.size() * sizeof(uint32_t)); + std::get(vq2.target_.clause_) + .sparse_values_.assign((char *)sparse.value().second.data(), + sparse.value().second.size() * sizeof(float)); + mvq.queries.push_back(vq2); + + auto result = collection->Query(mvq); + ASSERT_TRUE(result.has_value()) << result.error().message(); + EXPECT_GT(result.value().size(), 0u); + EXPECT_LE(result.value().size(), 5u); +} + +TEST_F(CollectionTest, Feature_MultiQuery_CallbackReranker) { + FileHelper::RemoveDirectory(col_path); + + int doc_count = 100; + auto schema = TestHelper::CreateNormalSchema(); + auto options = CollectionOptions{false, true, 100 * 1024 * 1024}; + auto collection = TestHelper::CreateCollectionWithDoc(col_path, *schema, + options, 0, doc_count); + ASSERT_NE(collection, nullptr); + + auto query_doc = TestHelper::CreateDoc(1, *schema); + + // Use CallbackReranker with a lambda that merges and sorts by score + bool callback_invoked = false; + auto callback_fn = [&callback_invoked]( + const std::map &query_results, + int topn) -> DocPtrList { + callback_invoked = true; + DocPtrList all_docs; + for (const auto &[_, docs] : query_results) { + for (const auto &doc : docs) { + all_docs.push_back(doc); + } + } + std::sort(all_docs.begin(), all_docs.end(), + [](const Doc::Ptr &a, const Doc::Ptr &b) { + return a->score() > b->score(); + }); + if (static_cast(all_docs.size()) > topn) { + all_docs.resize(topn); + } + return all_docs; + }; + + MultiQuery mvq; + mvq.topk = 10; + mvq.reranker = std::make_shared(callback_fn); + + // Query dense_fp32 field + { + SubQuery vq; + vq.num_candidates_ = 10; + vq.target_.field_name_ = "dense_fp32"; + auto vector = query_doc.get>("dense_fp32"); + ASSERT_TRUE(vector.has_value()); + std::get(vq.target_.clause_) + .query_vector_.assign((char *)vector.value().data(), + vector.value().size() * sizeof(float)); + mvq.queries.push_back(vq); + } + + // Query sparse_fp32 field + { + SubQuery vq; + vq.num_candidates_ = 10; + vq.target_.field_name_ = "sparse_fp32"; + auto sparse = + query_doc.get, std::vector>>( + "sparse_fp32"); + ASSERT_TRUE(sparse.has_value()); + std::get(vq.target_.clause_) + .sparse_indices_.assign((char *)sparse.value().first.data(), + sparse.value().first.size() * sizeof(uint32_t)); + std::get(vq.target_.clause_) + .sparse_values_.assign((char *)sparse.value().second.data(), + sparse.value().second.size() * sizeof(float)); + mvq.queries.push_back(vq); + } + + auto result = collection->Query(mvq); + ASSERT_TRUE(result.has_value()) << result.error().message(); + EXPECT_TRUE(callback_invoked); + EXPECT_GT(result.value().size(), 0u); + EXPECT_LE(result.value().size(), 10u); + + // Verify results are sorted by score descending + for (size_t i = 1; i < result.value().size(); ++i) { + EXPECT_GE(result.value()[i - 1]->score(), result.value()[i]->score()); + } +} + TEST_F(CollectionTest, Feature_GroupByQuery) {} TEST_F(CollectionTest, Feature_AddColumn_General) { diff --git a/tests/db/reranker_test.cc b/tests/db/reranker_test.cc new file mode 100644 index 000000000..adbd3fd53 --- /dev/null +++ b/tests/db/reranker_test.cc @@ -0,0 +1,283 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#define _USE_MATH_DEFINES +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace zvec; + +namespace { + +Doc::Ptr MakeDoc(const std::string &id, float score) { + auto doc = std::make_shared(); + doc->set_pk(id); + doc->set_score(score); + return doc; +} + +CollectionSchema::Ptr MakeSchema( + const std::vector> &fields) { + auto schema = std::make_shared("test"); + for (const auto &[name, metric] : fields) { + auto field = std::make_shared( + name, DataType::VECTOR_FP16, /*dimension=*/4, /*nullable=*/false, + std::make_shared(metric)); + schema->add_field(field); + } + return schema; +} + +} // namespace + +// ==================== RrfReranker Tests ==================== + +TEST(RrfRerankerTest, BasicRRF) { + RrfReranker reranker(/*rank_constant=*/60); + + // Two vector fields, each returning 3 documents with some overlap + std::map query_results; + query_results["vec1"] = {MakeDoc("a", 0.9f), MakeDoc("b", 0.8f), + MakeDoc("c", 0.7f)}; + query_results["vec2"] = {MakeDoc("b", 0.95f), MakeDoc("a", 0.85f), + MakeDoc("d", 0.75f)}; + + auto result = reranker.rerank(query_results, /*topn=*/10); + ASSERT_TRUE(result.has_value()); + auto &results = result.value(); + + // "a" appears at rank 0 in vec1 and rank 1 in vec2: + // rrf_score = 1/(60+0+1) + 1/(60+1+1) = 1/61 + 1/62 + // "b" appears at rank 1 in vec1 and rank 0 in vec2: + // rrf_score = 1/(60+1+1) + 1/(60+0+1) = 1/62 + 1/61 + // So a and b should have equal scores and be at the top + ASSERT_GE(results.size(), 3u); + + // "a" and "b" should have the highest RRF scores + EXPECT_EQ(results[0]->pk(), "a"); + EXPECT_EQ(results[1]->pk(), "b"); + // Verify scores are close (a and b have same RRF score) + EXPECT_NEAR(results[0]->score(), results[1]->score(), 1e-10); +} + +TEST(RrfRerankerTest, Topn) { + RrfReranker reranker(/*rank_constant=*/60); + + std::map query_results; + query_results["vec1"] = {MakeDoc("a", 0.9f), MakeDoc("b", 0.8f), + MakeDoc("c", 0.7f)}; + + auto result = reranker.rerank(query_results, /*topn=*/2); + ASSERT_TRUE(result.has_value()); + ASSERT_EQ(result.value().size(), 2u); +} + +TEST(RrfRerankerTest, SingleField) { + RrfReranker reranker(/*rank_constant=*/60); + + std::map query_results; + query_results["vec1"] = {MakeDoc("a", 0.9f), MakeDoc("b", 0.8f)}; + + auto result = reranker.rerank(query_results); + ASSERT_TRUE(result.has_value()); + auto &results = result.value(); + ASSERT_EQ(results.size(), 2u); + // With single field, RRF score for rank 0 > rank 1 + EXPECT_GT(results[0]->score(), results[1]->score()); +} + +TEST(RrfRerankerTest, EmptyResults) { + RrfReranker reranker(/*rank_constant=*/60); + + std::map query_results; + auto result = reranker.rerank(query_results); + ASSERT_TRUE(result.has_value()); + EXPECT_TRUE(result.value().empty()); +} + +// ==================== WeightedReranker Tests ==================== + +TEST(WeightedRerankerTest, BasicWeighted) { + auto schema = + MakeSchema({{"vec1", MetricType::L2}, {"vec2", MetricType::L2}}); + WeightedReranker reranker({{"vec1", 0.7}, {"vec2", 0.3}}); + reranker.bind_schema(schema); + + std::map query_results; + query_results["vec1"] = {MakeDoc("a", 0.5f), MakeDoc("b", 0.3f)}; + query_results["vec2"] = {MakeDoc("a", 0.8f), MakeDoc("c", 0.6f)}; + + auto result = reranker.rerank(query_results); + ASSERT_TRUE(result.has_value()); + auto &results = result.value(); + ASSERT_GE(results.size(), 2u); + // "a" appears in both fields, should have highest combined score + EXPECT_EQ(results[0]->pk(), "a"); +} + +TEST(WeightedRerankerTest, MixedMetrics) { + auto schema = + MakeSchema({{"vec1", MetricType::L2}, {"vec2", MetricType::COSINE}}); + WeightedReranker reranker({{"vec1", 0.5}, {"vec2", 0.5}}); + reranker.bind_schema(schema); + + std::map query_results; + query_results["vec1"] = {MakeDoc("a", 0.5f)}; + query_results["vec2"] = {MakeDoc("a", 0.4f)}; + + auto result = reranker.rerank(query_results); + ASSERT_TRUE(result.has_value()); + auto &results = result.value(); + ASSERT_EQ(results.size(), 1u); + EXPECT_EQ(results[0]->pk(), "a"); + // L2 normalize(0.5) = 1 - 2*atan(0.5)/pi ≈ 0.7048 + // COSINE normalize(0.4) = 1 - 0.4/2 = 0.8 + // weighted = 0.7048 * 0.5 + 0.8 * 0.5 ≈ 0.7524 + double l2_norm = 1.0 - 2.0 * std::atan(0.5) / M_PI; + double cos_norm = 1.0 - 0.4 / 2.0; + double expected = l2_norm * 0.5 + cos_norm * 0.5; + EXPECT_NEAR(results[0]->score(), expected, 1e-5); +} + +TEST(WeightedRerankerTest, MissingMetricError) { + auto schema = MakeSchema({{"vec1", MetricType::L2}}); + WeightedReranker reranker; + reranker.bind_schema(schema); + + std::map query_results; + query_results["vec1"] = {MakeDoc("a", 0.5f)}; + query_results["vec2"] = {MakeDoc("b", 0.3f)}; + + auto result = reranker.rerank(query_results); + ASSERT_FALSE(result.has_value()); +} + +TEST(WeightedRerankerTest, NormalizeL2) { + auto schema = MakeSchema({{"vec1", MetricType::L2}}); + WeightedReranker reranker; + reranker.bind_schema(schema); + + std::map query_results; + query_results["vec1"] = {MakeDoc("a", 0.0f), MakeDoc("b", 1.0f)}; + + auto result = reranker.rerank(query_results); + ASSERT_TRUE(result.has_value()); + auto &results = result.value(); + ASSERT_EQ(results.size(), 2u); + // L2 normalize(0.0) = 1.0, normalize(1.0) ∈ (0, 1) + EXPECT_NEAR(results[0]->score(), 1.0, 1e-10); + EXPECT_EQ(results[0]->pk(), "a"); + EXPECT_GT(results[1]->score(), 0.0); + EXPECT_LT(results[1]->score(), 1.0); +} + +TEST(WeightedRerankerTest, NormalizeIP) { + auto schema = MakeSchema({{"vec1", MetricType::IP}}); + WeightedReranker reranker; + reranker.bind_schema(schema); + + std::map query_results; + query_results["vec1"] = {MakeDoc("a", 0.0f), MakeDoc("b", 1.0f)}; + + auto result = reranker.rerank(query_results); + ASSERT_TRUE(result.has_value()); + auto &results = result.value(); + ASSERT_EQ(results.size(), 2u); + // IP normalize(1.0) > 0.5 > normalize(0.0) = 0.5... but b scores higher + EXPECT_EQ(results[0]->pk(), "b"); + EXPECT_GT(results[0]->score(), 0.5); + EXPECT_NEAR(results[1]->score(), 0.5, 1e-10); +} + +TEST(WeightedRerankerTest, NormalizeCosine) { + auto schema = MakeSchema({{"vec1", MetricType::COSINE}}); + WeightedReranker reranker; + reranker.bind_schema(schema); + + std::map query_results; + query_results["vec1"] = {MakeDoc("a", 0.0f), MakeDoc("b", 1.0f), + MakeDoc("c", 2.0f)}; + + auto result = reranker.rerank(query_results); + ASSERT_TRUE(result.has_value()); + auto &results = result.value(); + ASSERT_EQ(results.size(), 3u); + // COSINE normalize(0.0) = 1.0, normalize(1.0) = 0.5, normalize(2.0) = 0.0 + EXPECT_NEAR(results[0]->score(), 1.0, 1e-10); + EXPECT_NEAR(results[1]->score(), 0.5, 1e-10); + EXPECT_NEAR(results[2]->score(), 0.0, 1e-10); +} + +TEST(WeightedRerankerTest, Topn) { + auto schema = MakeSchema({{"vec1", MetricType::L2}}); + WeightedReranker reranker; + reranker.bind_schema(schema); + + std::map query_results; + query_results["vec1"] = {MakeDoc("a", 0.1f), MakeDoc("b", 0.2f), + MakeDoc("c", 0.3f)}; + + auto result = reranker.rerank(query_results, /*topn=*/2); + ASSERT_TRUE(result.has_value()); + ASSERT_EQ(result.value().size(), 2u); +} + + +// ==================== CallbackReranker Tests ==================== + +TEST(CallbackRerankerTest, BasicCallback) { + // Simple callback that returns docs sorted by score descending, limited to + // topn + CallbackReranker::Callback cb = + [](const std::map &query_results, + int topn) -> DocPtrList { + DocPtrList all_docs; + for (const auto &[_, docs] : query_results) { + for (const auto &doc : docs) { + all_docs.push_back(doc); + } + } + std::sort(all_docs.begin(), all_docs.end(), + [](const Doc::Ptr &a, const Doc::Ptr &b) { + return a->score() > b->score(); + }); + if (static_cast(all_docs.size()) > topn) { + all_docs.resize(topn); + } + return all_docs; + }; + + CallbackReranker reranker(cb); + + std::map query_results; + query_results["vec1"] = {MakeDoc("a", 0.5f), MakeDoc("b", 0.9f)}; + query_results["vec2"] = {MakeDoc("c", 0.7f)}; + + auto result = reranker.rerank(query_results, /*topn=*/10); + ASSERT_TRUE(result.has_value()); + auto &results = result.value(); + ASSERT_EQ(results.size(), 3u); + // Should be sorted by score descending + EXPECT_EQ(results[0]->pk(), "b"); + EXPECT_EQ(results[1]->pk(), "c"); + EXPECT_EQ(results[2]->pk(), "a"); +}