diff --git a/rabitqlib/index/ivf/ivf.hpp b/rabitqlib/index/ivf/ivf.hpp index 9739b38..d8bd111 100644 --- a/rabitqlib/index/ivf/ivf.hpp +++ b/rabitqlib/index/ivf/ivf.hpp @@ -3,6 +3,7 @@ #include #include +#include #include #include #include @@ -118,9 +119,7 @@ inline IVF::IVF(size_t n, size_t dim, size_t cluster_num, size_t bits, RotatorTy std::cerr.flush(); exit(1); }; - rotator_ = choose_rotator( - dim, RotatorType::FhtKacRotator, round_up_to_multiple(dim_, 64) - ); + rotator_ = choose_rotator(dim, type, round_up_to_multiple(dim_, 64)); padded_dim_ = rotator_->size(); /* check size */ assert(padded_dim_ % 64 == 0); @@ -332,9 +331,7 @@ inline void IVF::load(const char* filename) { input.read(reinterpret_cast(&this->ex_bits_), sizeof(size_t)); input.read(reinterpret_cast(&type_), sizeof(type_)); - rotator_ = choose_rotator( - dim_, RotatorType::FhtKacRotator, round_up_to_multiple(dim_, 64) - ); + rotator_ = choose_rotator(dim_, type_, round_up_to_multiple(dim_, 64)); padded_dim_ = rotator_->size(); /* Load number of vectors of each cluster */ @@ -376,12 +373,12 @@ inline void IVF::search( PID* __restrict__ results, bool use_hacc = true ) const { - if (nprobe > num_cluster_) { - nprobe = num_cluster_; - } + nprobe = std::min(nprobe, num_cluster_); // corner case std::vector rotated_query(padded_dim_); this->rotator_->rotate(query, rotated_query.data()); + std::cout << l2norm_sqr(query, dim_) << '\t' << l2norm_sqr(rotated_query.data(), padded_dim_) << '\n'; + // use initer to get closest nprobe centroids std::vector> centroid_dist(nprobe); this->initer_->centroids_distances(rotated_query.data(), nprobe, centroid_dist); diff --git a/rabitqlib/index/symqg/qg.hpp b/rabitqlib/index/symqg/qg.hpp index 943f38d..7f6aed6 100644 --- a/rabitqlib/index/symqg/qg.hpp +++ b/rabitqlib/index/symqg/qg.hpp @@ -341,9 +341,7 @@ template inline void QuantizedGraph::initialize() { ::delete rotator_; - rotator_ = choose_rotator( - dim_, RotatorType::FhtKacRotator, round_up_to_multiple(dim_, 64) - ); + rotator_ = choose_rotator(dim_, type_, round_up_to_multiple(dim_, 64)); padded_dim_ = rotator_->size(); /* check size */