Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 5 additions & 19 deletions be/src/storage/index/ann/ann_index_reader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,20 +35,6 @@
#include "util/once.h"

namespace doris::segment_v2 {
#include "common/compile_check_begin.h"
// Copies an index search result into caller-owned outputs.
//
// search_result: result produced by the vector index; its `distances` and
//                `roaring` members must be non-null (checked with DCHECK).
// distance:      overwritten with the first N entries of
//                search_result.distances, where N is the cardinality of
//                search_result.roaring.
// roaring:       overwritten with a copy of *search_result.roaring.
void AnnIndexReader::update_result(const IndexSearchResult& search_result,
                                   std::vector<float>& distance, roaring::Roaring& roaring) {
    DCHECK(search_result.distances != nullptr);
    DCHECK(search_result.roaring != nullptr);

    // The bitmap cardinality decides how many distance entries are copied.
    const size_t result_count = search_result.roaring->cardinality();
    const float* first_distance = search_result.distances.get();

    // assign() leaves `distance` with exactly result_count elements, replacing
    // whatever it held before.
    distance.assign(first_distance, first_distance + result_count);
    roaring = *search_result.roaring;
}

AnnIndexReader::AnnIndexReader(const TabletIndex* index_meta,
std::shared_ptr<IndexFileReader> index_file_reader)
: _index_meta(*index_meta), _index_file_reader(index_file_reader) {
Expand Down Expand Up @@ -176,12 +162,12 @@ Status AnnIndexReader::query(io::IOContext* io_ctx, AnnTopNParam* param, AnnInde
DORIS_CHECK(index_search_result.roaring != nullptr);
DORIS_CHECK(index_search_result.distances != nullptr);
DORIS_CHECK(index_search_result.row_ids != nullptr);
param->distance = std::make_unique<std::vector<float>>();
{
SCOPED_TIMER(&(stats->result_process_costs_ns));
update_result(index_search_result, *param->distance, *param->roaring);
param->distance = index_search_result.distances;
*param->roaring = *index_search_result.roaring;
}
param->row_ids = std::move(index_search_result.row_ids);
param->row_ids = index_search_result.row_ids;
}

double search_costs_ms = static_cast<double>(stats->search_costs_ns.value()) / 1000.0;
Expand Down Expand Up @@ -267,13 +253,13 @@ Status AnnIndexReader::range_search(const AnnRangeSearchParams& params,
DCHECK(search_result.row_ids->size() == search_result.roaring->cardinality())
<< "Row ids size: " << search_result.row_ids->size()
<< ", roaring size: " << search_result.roaring->cardinality();
result->row_ids = std::move(search_result.row_ids);
result->row_ids = search_result.row_ids;
} else {
result->row_ids = nullptr;
}

if (search_result.distances != nullptr) {
result->distance = std::move(search_result.distances);
result->distance = search_result.distances;
} else {
result->distance = nullptr;
}
Expand Down
3 changes: 0 additions & 3 deletions be/src/storage/index/ann/ann_index_reader.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,6 @@ class AnnIndexReader : public IndexReader {
std::shared_ptr<IndexFileReader> index_file_reader);
~AnnIndexReader() override = default;

static void update_result(const IndexSearchResult&, std::vector<float>& distance,
roaring::Roaring& row_id);

Status load_index(io::IOContext* io_ctx);

// Try to load index, return true if successful, false if failed
Expand Down
17 changes: 10 additions & 7 deletions be/src/storage/index/ann/ann_search_params.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,10 @@
#include <gen_cpp/Metrics_types.h>
#include <gen_cpp/Opcodes_types.h>

#include <memory>
#include <roaring/roaring.hh>
#include <string>
#include <vector>

#include "exec/scan/vector_search_user_params.h"
#include "runtime/runtime_profile.h"
Expand Down Expand Up @@ -116,8 +118,8 @@ struct AnnTopNParam {
doris::VectorSearchUserParams _user_params;
roaring::Roaring* roaring;
size_t rows_of_segment = 0;
std::unique_ptr<std::vector<float>> distance = nullptr;
std::unique_ptr<std::vector<uint64_t>> row_ids = nullptr;
std::shared_ptr<float[]> distance = nullptr;
std::shared_ptr<std::vector<uint64_t>> row_ids = nullptr;
std::unique_ptr<AnnIndexStats> stats = nullptr;
};

Expand All @@ -136,22 +138,23 @@ struct AnnRangeSearchParams {

struct AnnRangeSearchResult {
std::shared_ptr<roaring::Roaring> roaring;
std::unique_ptr<std::vector<uint64_t>> row_ids;
std::unique_ptr<float[]> distance;
std::shared_ptr<std::vector<uint64_t>> row_ids;
std::shared_ptr<float[]> distance;
};

/*
This struct is used to wrap the search result of a vector index.
roaring is a bitmap that contains the row ids that satisfy the search condition.
row_ids is a vector of row ids that are returned by the search, it could be used by virtual_column_iterator to do column filter.
row_ids is an ordered vector of row ids returned by the search. row_ids[i] is aligned with
distances[i], so virtual_column_iterator can map each distance back to its segment row id.
distances is a vector of distances that are returned by the search.
For range search, if the condition is not le_or_lt, the row_ids and distances will be nullptr.
*/
struct IndexSearchResult {
IndexSearchResult() = default;

std::unique_ptr<float[]> distances = nullptr;
std::unique_ptr<std::vector<uint64_t>> row_ids = nullptr;
std::shared_ptr<float[]> distances = nullptr;
std::shared_ptr<std::vector<uint64_t>> row_ids = nullptr;
std::shared_ptr<roaring::Roaring> roaring = nullptr;
// Internal engine timings (ns)
int64_t engine_search_ns = 0; // time spent in the underlying index search call
Expand Down
8 changes: 4 additions & 4 deletions be/src/storage/index/ann/ann_topn_runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ Status AnnTopNRuntime::prepare(RuntimeState* state, const RowDescriptor& row_des
Status AnnTopNRuntime::evaluate_vector_ann_search(segment_v2::AnnIndexIterator* ann_index_iterator,
roaring::Roaring* roaring, size_t rows_of_segment,
IColumn::MutablePtr& result_column,
std::unique_ptr<std::vector<uint64_t>>& row_ids,
std::shared_ptr<std::vector<uint64_t>>& row_ids,
segment_v2::AnnIndexStats& ann_index_stats) {
DCHECK(ann_index_iterator != nullptr);
DCHECK(_order_by_expr_ctx != nullptr);
Expand Down Expand Up @@ -220,13 +220,13 @@ Status AnnTopNRuntime::evaluate_vector_ann_search(segment_v2::AnnIndexIterator*
DCHECK(ann_query_params.distance != nullptr);
DCHECK(ann_query_params.row_ids != nullptr);

size_t num_results = ann_query_params.distance->size();
size_t num_results = ann_query_params.row_ids->size();
auto result_column_float = ColumnFloat32::create(num_results);
for (size_t i = 0; i < num_results; ++i) {
result_column_float->get_data()[i] = (*ann_query_params.distance)[i];
result_column_float->get_data()[i] = ann_query_params.distance[i];
}
result_column = std::move(result_column_float);
row_ids = std::move(ann_query_params.row_ids);
row_ids = ann_query_params.row_ids;
ann_index_stats = *ann_query_params.stats;
return Status::OK();
}
Expand Down
2 changes: 1 addition & 1 deletion be/src/storage/index/ann/ann_topn_runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ class AnnTopNRuntime {
Status evaluate_vector_ann_search(segment_v2::AnnIndexIterator* ann_index_iterator,
roaring::Roaring* row_bitmap, size_t rows_of_segment,
IColumn::MutablePtr& result_column,
std::unique_ptr<std::vector<uint64_t>>& row_ids,
std::shared_ptr<std::vector<uint64_t>>& row_ids,
segment_v2::AnnIndexStats& ann_index_stats);

/**
Expand Down
20 changes: 10 additions & 10 deletions be/src/storage/index/ann/faiss_ann_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -726,8 +726,8 @@ doris::Status FaissVectorIndex::ann_topn_search(const float* query_vec, int k,
result.roaring = std::make_shared<roaring::Roaring>();
update_roaring(labels, k, *result.roaring);
size_t roaring_cardinality = result.roaring->cardinality();
result.distances = std::make_unique<float[]>(roaring_cardinality);
result.row_ids = std::make_unique<std::vector<uint64_t>>();
result.distances = std::shared_ptr<float[]>(new float[roaring_cardinality]);
result.row_ids = std::make_shared<std::vector<uint64_t>>();
result.row_ids->resize(roaring_cardinality);

if (_metric == AnnIndexMetric::L2) {
Expand Down Expand Up @@ -837,17 +837,17 @@ doris::Status FaissVectorIndex::range_search(const float* query_vec, const float

size_t begin = native_search_result.lims[0];
size_t end = native_search_result.lims[1];
auto row_ids = std::make_unique<std::vector<uint64_t>>();
auto row_ids = std::make_shared<std::vector<uint64_t>>();
row_ids->resize(end - begin);
if (params.is_le_or_lt) {
if (_metric == AnnIndexMetric::L2) {
std::unique_ptr<float[]> distances_ptr;
std::shared_ptr<float[]> distances_ptr;
float* distances = nullptr;
auto roaring = std::make_shared<roaring::Roaring>();
{
// Engine convert: build roaring, row_ids, distances from FAISS result
SCOPED_RAW_TIMER(&result.engine_convert_ns);
distances_ptr = std::make_unique<float[]>(end - begin);
distances_ptr = std::shared_ptr<float[]>(new float[end - begin]);
distances = distances_ptr.get();
// The distance returned by Faiss is actually the squared distance.
// So we need to take the square root of the squared distance.
Expand All @@ -857,8 +857,8 @@ doris::Status FaissVectorIndex::range_search(const float* query_vec, const float
distances[i - begin] = sqrt(native_search_result.distances[i]);
}
}
result.distances = std::move(distances_ptr);
result.row_ids = std::move(row_ids);
result.distances = distances_ptr;
result.row_ids = row_ids;
result.roaring = roaring;

DCHECK(result.row_ids->size() == result.roaring->cardinality())
Expand Down Expand Up @@ -908,7 +908,7 @@ doris::Status FaissVectorIndex::range_search(const float* query_vec, const float
// For inner product, we can use the distance directly.
// range search on ip gets all vectors with inner product greater than or equal to the radius.
// when query condition is not le_or_lt, we can use the roaring and distance directly.
std::unique_ptr<float[]> distances_ptr = std::make_unique<float[]>(end - begin);
std::shared_ptr<float[]> distances_ptr(new float[end - begin]);
float* distances = distances_ptr.get();
auto roaring = std::make_shared<roaring::Roaring>();
// The distance returned by Faiss is actually the squared distance.
Expand All @@ -918,8 +918,8 @@ doris::Status FaissVectorIndex::range_search(const float* query_vec, const float
roaring->add(cast_set<UInt32>(native_search_result.labels[i]));
distances[i - begin] = native_search_result.distances[i];
}
result.distances = std::move(distances_ptr);
result.row_ids = std::move(row_ids);
result.distances = distances_ptr;
result.row_ids = row_ids;
result.roaring = roaring;

DCHECK(result.row_ids->size() == result.roaring->cardinality())
Expand Down
5 changes: 2 additions & 3 deletions be/src/storage/segment/segment_iterator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -918,7 +918,7 @@ Status SegmentIterator::_apply_ann_topn_predicate() {
return Status::OK();
}
IColumn::MutablePtr result_column;
std::unique_ptr<std::vector<uint64_t>> result_row_ids;
std::shared_ptr<std::vector<uint64_t>> result_row_ids;
segment_v2::AnnIndexStats ann_index_stats;

// Try to load ANN index before search
Expand Down Expand Up @@ -976,8 +976,7 @@ Status SegmentIterator::_apply_ann_topn_predicate() {
"Virtual column iterator, column_idx {}, is materialized with {} rows", dst_col_idx,
result_row_ids->size());
// reference count of result_column should be 1, so move will not issue any data copy.
virtual_column_iter->prepare_materialization(std::move(result_column),
std::move(result_row_ids));
virtual_column_iter->prepare_materialization(std::move(result_column), result_row_ids);

_need_read_data_indices[src_cid] = false;
VLOG_DEBUG << fmt::format(
Expand Down
4 changes: 2 additions & 2 deletions be/src/storage/segment/virtual_column_iterator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ Status VirtualColumnIterator::init(const ColumnIteratorOptions& opts) {
}

void VirtualColumnIterator::prepare_materialization(IColumn::Ptr column,
std::unique_ptr<std::vector<uint64_t>> labels) {
std::shared_ptr<std::vector<uint64_t>> labels) {
DCHECK(labels->size() == column->size()) << "labels size: " << labels->size()
<< ", materialized column size: " << column->size();
// 1. do sort to labels
Expand Down Expand Up @@ -165,4 +165,4 @@ Status VirtualColumnIterator::read_by_rowids(const rowid_t* rowids, const size_t
return Status::OK();
}

} // namespace doris::segment_v2
} // namespace doris::segment_v2
4 changes: 2 additions & 2 deletions be/src/storage/segment/virtual_column_iterator.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class VirtualColumnIterator : public ColumnIterator {
~VirtualColumnIterator() override = default;

MOCK_FUNCTION void prepare_materialization(IColumn::Ptr column,
std::unique_ptr<std::vector<uint64_t>> labels);
std::shared_ptr<std::vector<uint64_t>> labels);

Status init(const ColumnIteratorOptions& opts) override;

Expand All @@ -61,4 +61,4 @@ class VirtualColumnIterator : public ColumnIterator {
ordinal_t _current_ordinal = 0;
};

} // namespace doris::segment_v2
} // namespace doris::segment_v2
36 changes: 0 additions & 36 deletions be/test/storage/index/ann/ann_index_reader_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -384,42 +384,6 @@ TEST_F(AnnIndexReaderTest, TestRangeSearchIVFWithoutLoadIndex) {
}
}

TEST_F(AnnIndexReaderTest, TestUpdateResultStatic) {
    // Checks that the static AnnIndexReader::update_result copies both the
    // distances and the row bitmap out of an IndexSearchResult.
    const size_t num_results = 3;
    const uint32_t input_rows[] = {10, 20, 30};
    const float input_distances[] = {1.5f, 2.3f, 3.1f};

    // Build the search result fixture from the two tables above.
    segment_v2::IndexSearchResult search_result;
    search_result.roaring = std::make_shared<roaring::Roaring>();
    auto distances = std::make_unique<float[]>(num_results);
    for (size_t i = 0; i < num_results; ++i) {
        search_result.roaring->add(input_rows[i]);
        distances[i] = input_distances[i];
    }
    search_result.distances = std::move(distances);

    // Run the function under test into fresh output containers.
    std::vector<float> distance_vec;
    roaring::Roaring result_roaring;
    segment_v2::AnnIndexReader::update_result(search_result, distance_vec, result_roaring);

    // Distances come back in input order, one per bitmap entry.
    EXPECT_EQ(distance_vec.size(), num_results);
    EXPECT_FLOAT_EQ(distance_vec[0], 1.5f);
    EXPECT_FLOAT_EQ(distance_vec[1], 2.3f);
    EXPECT_FLOAT_EQ(distance_vec[2], 3.1f);

    // The output bitmap holds exactly the rows that were added.
    EXPECT_EQ(result_roaring.cardinality(), num_results);
    EXPECT_TRUE(result_roaring.contains(10));
    EXPECT_TRUE(result_roaring.contains(20));
    EXPECT_TRUE(result_roaring.contains(30));
}

TEST_F(AnnIndexReaderTest, TestRangeSearchWithDifferentParameters) {
auto reader = std::make_unique<segment_v2::AnnIndexReader>(_tablet_index.get(),
_mock_index_file_reader);
Expand Down
Loading
Loading