Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion be/src/storage/index/ann/faiss_ann_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ namespace doris::segment_v2 {
namespace {

std::mutex g_omp_thread_mutex;
std::condition_variable g_omp_thread_cv;
int g_index_threads_in_use = 0;

// Guard that ensures the total OpenMP threads used by concurrent index builds
Expand All @@ -71,7 +72,11 @@ class ScopedOmpThreadBudget {
// For each index build, reserve at most half of the remaining threads, at least 1 thread.
ScopedOmpThreadBudget() {
std::unique_lock<std::mutex> lock(g_omp_thread_mutex);
auto thread_cap = config::omp_threads_limit - g_index_threads_in_use;
auto omp_threads_limit = get_omp_threads_limit();
// Block until there is at least one OpenMP slot available under the global cap.
g_omp_thread_cv.wait(lock, [&] { return g_index_threads_in_use < omp_threads_limit; });
auto thread_cap = omp_threads_limit - g_index_threads_in_use;
// Keep headroom for other concurrent index builds: take up to half of remaining budget.
_reserved_threads = std::max(1, thread_cap / 2);
g_index_threads_in_use += _reserved_threads;
DorisMetrics::instance()->ann_index_build_index_threads->increment(_reserved_threads);
Expand All @@ -88,6 +93,8 @@ class ScopedOmpThreadBudget {
if (g_index_threads_in_use < 0) {
g_index_threads_in_use = 0;
}
// Wake waiting index builders so they can compete for the released OpenMP budget.
g_omp_thread_cv.notify_all();
VLOG_DEBUG << fmt::format(
"ScopedOmpThreadBudget release threads reserved={}, remaining_in_use={}, limit={}",
_reserved_threads, g_index_threads_in_use, get_omp_threads_limit());
Expand Down
65 changes: 65 additions & 0 deletions be/test/storage/index/ann/faiss_vector_index_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,24 @@
#include <gtest/gtest.h>

#include <algorithm>
#include <atomic>
#include <chrono>
#include <cstddef>
#include <limits>
#include <memory>
#include <random>
#include <string>
#include <thread>
#include <vector>

#include "common/config.h"
#include "common/metrics/doris_metrics.h"
#include "storage/index/ann/ann_index.h"
#include "storage/index/ann/ann_search_params.h"
#include "storage/index/ann/faiss_ann_index.h"
// metrics.h not used directly here
#include "storage/index/ann/vector_search_utils.h"
#include "util/defer_op.h"

using namespace doris::segment_v2;

Expand Down Expand Up @@ -233,6 +239,65 @@ TEST_F(VectorSearchTest, UpdateRoaring) {
}
}

TEST_F(VectorSearchTest, OmpThreadBudgetNeverExceedsLimit) {
constexpr int kWorkers = 2;
constexpr int kDim = 64;
// Keep this workload small to avoid long-running BE UT under ASAN.
constexpr int kNumVectors = 500;

const auto old_omp_threads_limit = config::omp_threads_limit;
config::omp_threads_limit = 1;
Defer reset_omp_threads_limit(
[&old_omp_threads_limit]() { config::omp_threads_limit = old_omp_threads_limit; });

auto* budget_metric = DorisMetrics::instance()->ann_index_build_index_threads;
std::atomic<bool> start {false};
std::atomic<int> finished {0};
std::vector<std::thread> workers;
workers.reserve(kWorkers);

for (int worker_id = 0; worker_id < kWorkers; ++worker_id) {
workers.emplace_back([&start, &finished, worker_id]() {
auto index = std::make_unique<FaissVectorIndex>();
FaissBuildParameter params;
params.dim = kDim;
params.max_degree = 8;
params.ef_construction = 20;
params.index_type = FaissBuildParameter::IndexType::HNSW;
index->build(params);

std::vector<float> vectors(static_cast<size_t>(kNumVectors) * kDim,
static_cast<float>(worker_id + 1));
while (!start.load(std::memory_order_acquire)) {
std::this_thread::yield();
}

auto st = index->add(kNumVectors, vectors.data());
EXPECT_TRUE(st.ok()) << st.to_string();
finished.fetch_add(1, std::memory_order_acq_rel);
});
}

start.store(true, std::memory_order_release);

int64_t observed_peak = 0;
auto deadline = std::chrono::steady_clock::now() + std::chrono::seconds(20);
while (finished.load(std::memory_order_acquire) < kWorkers &&
std::chrono::steady_clock::now() < deadline) {
observed_peak = std::max<int64_t>(observed_peak, budget_metric->value());
std::this_thread::sleep_for(std::chrono::milliseconds(1));
}

for (auto& worker : workers) {
worker.join();
}

observed_peak = std::max<int64_t>(observed_peak, budget_metric->value());
EXPECT_EQ(finished.load(std::memory_order_acquire), kWorkers);
EXPECT_LE(observed_peak, 1);
EXPECT_EQ(budget_metric->value(), 0);
}

TEST_F(VectorSearchTest, CompareResultWithNativeFaiss1) {
const size_t iterations = 3;
// Create random number generator
Expand Down
Loading