diff --git a/rabitqlib/quantization/rabitq_impl.hpp b/rabitqlib/quantization/rabitq_impl.hpp index af0371d..ff2ac8a 100644 --- a/rabitqlib/quantization/rabitq_impl.hpp +++ b/rabitqlib/quantization/rabitq_impl.hpp @@ -100,6 +100,11 @@ inline void one_bit_code_with_factor( // dot product between centroid and xu_cb T ip_cent_xucb = dot_product(centroid, xu_cb.data(), dim); + // corner case + if (ip_resi_xucb == 0) { + ip_resi_xucb = std::numeric_limits::infinity(); + } + // We use unnormalized vector to get error factor. To be more specific, // sqrt((1 - ^2) / ^2) / sqrt(dim - 1) = 3rd item in following // expression @@ -465,6 +470,11 @@ inline void ex_bits_code_with_factor( T ip_resi_xucb = dot_product(residual_arr.data(), xu_cb.data(), dim); T ip_cent_xucb = dot_product(centroid, xu_cb.data(), dim); + // corner case + if (ip_resi_xucb == 0) { + ip_resi_xucb = std::numeric_limits::infinity(); + } + T tmp_error = l2_norm * kConstEpsilon * std::sqrt( @@ -556,7 +566,8 @@ static inline void rabitq_scalar_impl( float norm_data = std::sqrt(l2norm_sqr(residual_arr.data(), dim)); float norm_quan = std::sqrt(l2norm_sqr(u_cb.data(), dim)); - float cos_similarity = dot_product(residual_arr.data(), u_cb.data(), dim) / (norm_data * norm_quan); + float cos_similarity = + dot_product(residual_arr.data(), u_cb.data(), dim) / (norm_data * norm_quan); if (scalar_quantizer_type == ScalarQuantizerType::RECONSTRUCTION) { delta = norm_data / norm_quan * cos_similarity;