Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions rabitqlib/index/ivf/ivf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ class IVF {
void search(const float*, size_t, size_t, PID*, bool) const;

[[nodiscard]] size_t padded_dim() const { return this->padded_dim_; }

[[nodiscard]] size_t num_clusters() const { return this->num_cluster_; }
};

inline IVF::IVF(size_t n, size_t dim, size_t cluster_num, size_t bits, RotatorType type)
Expand Down Expand Up @@ -374,6 +376,9 @@ inline void IVF::search(
PID* __restrict__ results,
bool use_hacc = true
) const {
if (nprobe > num_cluster_) {
nprobe = num_cluster_;
}
std::vector<float> rotated_query(padded_dim_);
this->rotator_->rotate(query, rotated_query.data());

Expand Down
4 changes: 4 additions & 0 deletions sample/ivf_rabitq_querying.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@ int main(int argc, char** argv) {
for (size_t r = 0; r < test_round; r++) {
for (size_t l = 0; l < length; ++l) {
size_t nprobe = nprobes[l];
if (nprobe > ivf.num_clusters()) {
std::cout << "nprobe " << nprobe << " is larger than number of clusters, ";
std::cout << "will use nprobe = num_cluster (" << ivf.num_clusters() << ").\n";
}
size_t total_correct = 0;
float total_time = 0;
std::vector<PID> results(topk);
Expand Down