diff --git a/rabitqlib/index/ivf/ivf.hpp b/rabitqlib/index/ivf/ivf.hpp index 895490f..9739b38 100644 --- a/rabitqlib/index/ivf/ivf.hpp +++ b/rabitqlib/index/ivf/ivf.hpp @@ -101,6 +101,8 @@ class IVF { void search(const float*, size_t, size_t, PID*, bool) const; [[nodiscard]] size_t padded_dim() const { return this->padded_dim_; } + + [[nodiscard]] size_t num_clusters() const { return this->num_cluster_; } }; inline IVF::IVF(size_t n, size_t dim, size_t cluster_num, size_t bits, RotatorType type) @@ -374,6 +376,9 @@ inline void IVF::search( PID* __restrict__ results, bool use_hacc = true ) const { + if (nprobe > num_cluster_) { + nprobe = num_cluster_; + } std::vector rotated_query(padded_dim_); this->rotator_->rotate(query, rotated_query.data()); diff --git a/sample/ivf_rabitq_querying.cpp b/sample/ivf_rabitq_querying.cpp index 8a4149f..8e89603 100644 --- a/sample/ivf_rabitq_querying.cpp +++ b/sample/ivf_rabitq_querying.cpp @@ -86,6 +86,10 @@ int main(int argc, char** argv) { for (size_t r = 0; r < test_round; r++) { for (size_t l = 0; l < length; ++l) { size_t nprobe = nprobes[l]; + if (nprobe > ivf.num_clusters()) { + std::cout << "nprobe " << nprobe << " is larger than number of clusters, "; + std::cout << "will use nprobe = num_cluster (" << ivf.num_clusters() << ").\n"; + } size_t total_correct = 0; float total_time = 0; std::vector results(topk);