Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions test/inte/global_index_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2490,6 +2490,15 @@ TEST_P(GlobalIndexTest, TestLuceneWriteCommitScanReadIndexWithScore) {
/*pre_filter=*/std::nullopt)));
ASSERT_TRUE(index_result->ToString().find("row ids: {0,1,2}") != std::string::npos);
}
{
std::optional<RoaringBitmap64> pre_filter = RoaringBitmap64::From({1, 2, 3});
ASSERT_OK_AND_ASSIGN(
auto index_result,
index_reader->VisitFullTextSearch(std::make_shared<FullTextSearch>(
"f0",
/*limit=*/10, "document", FullTextSearch::SearchType::MATCH_ALL, pre_filter)));
ASSERT_TRUE(index_result->ToString().find("row ids: {1,2}") != std::string::npos);
}
{
ASSERT_OK_AND_ASSIGN(auto index_result,
index_reader->VisitFullTextSearch(std::make_shared<FullTextSearch>(
Expand Down
107 changes: 0 additions & 107 deletions third_party/lumina/OptionsReference.md

This file was deleted.

4 changes: 2 additions & 2 deletions third_party/lumina/VERSION
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
tag: v0.1.0
bd7ac880cf34af4d267fcaf16773627cad122463
tag: v0.2.1
c88ce90ed44b7037e3a307a36627cbd030e5eb60
1 change: 0 additions & 1 deletion third_party/lumina/include/lumina/api/Dataset.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
* limitations under the License.
*/


#pragma once

#include <cstdint>
Expand Down
3 changes: 3 additions & 0 deletions third_party/lumina/include/lumina/core/Constants.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/

#pragma once
#include <cstdint>
#include <string_view>
namespace lumina::core {

Expand Down Expand Up @@ -56,6 +57,7 @@ constexpr std::string_view kEncodingType = "encoding.type"; // Encoding type
constexpr std::string_view kEncodingRawf32 = "rawf32";
constexpr std::string_view kEncodingSQ8 = "sq8";
constexpr std::string_view kEncodingPQ = "pq";
constexpr std::string_view kEncodingRabitQ = "rabitq";
constexpr std::string_view kEncodingDummy = "dummy";

// IO options
Expand Down Expand Up @@ -88,6 +90,7 @@ constexpr std::string_view kExtensionPrefix = "extension.";
constexpr std::string_view kExtensionSearchWithFilter = "extension.search_with_filter";
constexpr std::string_view kExtensionCkptThreshold = "extension.build.ckpt.threshold";
constexpr std::string_view kExtensionCkptCount = "extension.build.ckpt.count";
constexpr std::string_view kExtensionGetVector = "extension.search.get_vector";

/* constexpr std::string_view kExtensionFilterDsl = "filter.dsl"; */
/* constexpr std::string_view kExtensionFilterTags = "filter.tags"; */
Expand Down
152 changes: 151 additions & 1 deletion third_party/lumina/include/lumina/distance/EncodedDistance.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,27 +16,74 @@

#pragma once

#include <cstddef>
#include <cstdint>
#include <lumina/distance/Metric.h>
#include <lumina/distance/MetricDistance.h>
#include <lumina/distance/encode_space/EncodedRowSource.h>
#include <lumina/distance/encode_space/EncodingTypes.h>
#include <span>
#include <type_traits>
#include <utility>

namespace lumina::dist {

// Encoded Distance Interface (CPO)
//
// This file defines the core abstraction for computing distances between a floating-point query
// and encoded (compressed/quantized) vectors.
//
// Workflow Overview:
// 1. Prepare: One-time processing of the query for a specific encoding (e.g., building
// lookup tables for PQ or pre-calculating distance components). This amortizes the cost
// of complex setup across many distance evaluations.
// 2. Evaluate: Use the prepared state to compute distances for many rows.
// - BatchEval: Best for contiguous scans (dense layout).
// - GatherEval: Best for graph/IVF traversal where candidate rows are scattered.
//
// Design Goals:
// - Performance: Enables SIMD/ISA-specific optimizations (like AVX-512 gather) via TagInvoke.
// - Decoupling: Uses `EncodedRowSource` to abstract memory layout (flat, strided, or aux-data)
// away from the distance logic.
// - Stability: The API remains constant even if the underlying encoding format changes.

struct PrepareTag {
/**
* @brief Transforms a raw query into a "Prepared State" optimized for a specific encoding.
*
* Why use this?
* For many encodings (like Product Quantization), calculating a distance requires expensive
* per-query setup (e.g., pre-computing 256 distances to codebook centroids). 'Prepare'
* ensures this work is done only once per search.
*
* @param m The Metric (e.g., MetricL2, MetricIP).
* @param e The Encoding (e.g., SQ8, PQ).
* @param q The raw floating-point query vector.
* @param ctx Additional encoding-specific resources (e.g., codebooks, headers).
* @return An opaque state object (S) to be passed to Eval/Gather functions.
*/
template <class M, class E, class... Ctx>
requires TagInvocable<PrepareTag, M, E, std::span<const float>, Ctx&&...>
constexpr auto operator()(const M& m, const E& e, std::span<const float> q, Ctx&&... ctx) const
noexcept(noexcept(TagInvoke(std::declval<PrepareTag>(), m, e, q, std::forward<Ctx>(ctx)...)))
-> TagInvokeResult<PrepareTag, M, E, std::span<const float>, Ctx&&...>
{
return TagInvoke(*this, m, e, q, ctx...);
return TagInvoke(*this, m, e, q, std::forward<Ctx>(ctx)...);
}
};
inline constexpr PrepareTag Prepare {};

struct EvalEncodedTag {
/**
* @brief Computes distance for a single encoded row.
*
* Typically used within a loop when manual flow control is needed. However, for
* performance-critical code, BatchEval or GatherEval are usually preferred as they
* allow the implementation to use SIMD more effectively.
*
* @param s The prepared state from Prepare().
* @param r A view to a single encoded row (e.g., encode_space::EncodedRow).
*/
template <class M, class E, class S, class R>
requires TagInvocable<EvalEncodedTag, M, E, const S&, const R&>
constexpr auto operator()(const M& m, const E& e, const S& s, const R& r) const
Expand All @@ -49,6 +96,15 @@ struct EvalEncodedTag {
inline constexpr EvalEncodedTag EvalEncoded {};

struct BatchEvalEncodedTag {
/**
* @brief Computes distances for a contiguous block of encoded rows.
*
* Optimized for exhaustive scans. The implementation can assume rows are adjacent in memory,
* allowing for efficient linear prefetching and unrolled SIMD processing.
*
* @param b Description of the contiguous batch (base pointer + count).
* @param out Pointer to output float array (must hold at least b.n elements).
*/
template <class M, class E, class S>
requires TagInvocable<BatchEvalEncodedTag, M, E, const S&, const encode_space::EncodedBatch&, float*>
constexpr void operator()(const M& m, const E& e, const S& s, const encode_space::EncodedBatch& b, float* out) const
Expand All @@ -59,4 +115,98 @@ struct BatchEvalEncodedTag {
};
inline constexpr BatchEvalEncodedTag BatchEvalEncoded {};

struct GatherEvalEncodedTag {
/**
* @brief Computes distances for a non-contiguous set of row IDs.
*
* The primary workhorse for graph-based or IVF search. Given a list of candidate IDs,
* this function "gathers" the encoded data and computes distances.
*
* Performance Note: Specializations of this tag often use ISA-specific instructions
* (e.g., VPGATHERDD) or software pipelining to hide memory latency during random access.
*
* @param data The data source (models EncodedRowSource) providing mapping from ID to memory.
* @param rowIds The list of logical row IDs to evaluate.
* @param results Output span for distances (must match rowIds size).
*/
template <class M, class E, class S>
requires TagInvocable<GatherEvalEncodedTag, M, E, const S&, const std::byte*, std::span<const uint64_t>,
std::span<float>>
constexpr void operator()(const M& m, const E& e, const S& s, const std::byte* recordsBase,
std::span<const uint64_t> rowIds, std::span<float> results) const
noexcept(noexcept(TagInvoke(std::declval<GatherEvalEncodedTag>(), m, e, s, recordsBase, rowIds, results)))
{
TagInvoke(*this, m, e, s, recordsBase, rowIds, results);
}

template <class M, class E, class S, class DataSource>
requires(encode_space::EncodedRowSource<std::remove_cvref_t<DataSource>> &&
TagInvocable<GatherEvalEncodedTag, M, E, const S&, const DataSource&, std::span<const uint64_t>,
std::span<float>>)
constexpr void operator()(const M& m, const E& e, const S& s, const DataSource& data,
std::span<const uint64_t> rowIds, std::span<float> results) const
noexcept(noexcept(TagInvoke(std::declval<GatherEvalEncodedTag>(), m, e, s, data, rowIds, results)))
{
TagInvoke(*this, m, e, s, data, rowIds, results);
}

template <class M, class E, class S, class DataSource>
requires(encode_space::EncodedRowSource<std::remove_cvref_t<DataSource>> &&
!TagInvocable<GatherEvalEncodedTag, M, E, const S&, const DataSource&, std::span<const uint64_t>,
std::span<float>>)
constexpr void operator()(const M& m, const E& e, const S& s, const DataSource& data,
std::span<const uint64_t> rowIds, std::span<float> results) const
noexcept(noexcept(encode_space::GetEncodedRow(data, uint64_t {})) && noexcept(
EvalEncoded(m, e, s, encode_space::GetEncodedRow(data, uint64_t {}))))
{
for (std::size_t i = 0; i < rowIds.size(); ++i) {
const auto row = encode_space::GetEncodedRow(data, rowIds[i]);
results[i] = EvalEncoded(m, e, s, row);
}
}
};
inline constexpr GatherEvalEncodedTag GatherEvalEncoded {};

struct GatherEvalEncodedWithLowerBoundsTag {
/**
* @brief Evaluates distances AND returns a lower-bound estimate for each row.
*
* This is an optimization for multi-stage search (e.g., DiskANN).
*
* Lower bounds (D_lb) guarantee that D_true >= D_lb. In search algorithms, if D_lb is
* already greater than the current search radius, we can skip further processing of this
* candidate.
*
* Note: Not all encodings support lower bounds. If unsupported, use the standard GatherEvalEncoded.
*
* @param results Output span for the (potentially approximate) distances.
* @param lowerBounds Output span for the lower-bound estimates.
*/
template <class M, class E, class S>
requires TagInvocable<GatherEvalEncodedWithLowerBoundsTag, M, E, const S&, const std::byte*,
std::span<const uint64_t>, std::span<float>, std::span<float>>
constexpr void operator()(const M& m, const E& e, const S& s, const std::byte* recordsBase,
std::span<const uint64_t> rowIds, std::span<float> results,
std::span<float> lowerBounds) const
noexcept(noexcept(TagInvoke(std::declval<GatherEvalEncodedWithLowerBoundsTag>(), m, e, s, recordsBase, rowIds,
results, lowerBounds)))
{
TagInvoke(*this, m, e, s, recordsBase, rowIds, results, lowerBounds);
}

template <class M, class E, class S, class DataSource>
requires(encode_space::EncodedRowSource<std::remove_cvref_t<DataSource>> &&
TagInvocable<GatherEvalEncodedWithLowerBoundsTag, M, E, const S&, const DataSource&,
std::span<const uint64_t>, std::span<float>, std::span<float>>)
constexpr void operator()(const M& m, const E& e, const S& s, const DataSource& data,
std::span<const uint64_t> rowIds, std::span<float> results,
std::span<float> lowerBounds) const
noexcept(noexcept(TagInvoke(std::declval<GatherEvalEncodedWithLowerBoundsTag>(), m, e, s, data, rowIds, results,
lowerBounds)))
{
TagInvoke(*this, m, e, s, data, rowIds, results, lowerBounds);
}
};
inline constexpr GatherEvalEncodedWithLowerBoundsTag GatherEvalEncodedWithLowerBounds {};

} // namespace lumina::dist
Loading
Loading