From 8c6644cce1e72555b517a20966565ccd05b964b3 Mon Sep 17 00:00:00 2001 From: gouyt13clear Date: Thu, 12 Jun 2025 14:11:51 +0800 Subject: [PATCH] handle corner case for factors during quantization --- rabitqlib/quantization/rabitq_impl.hpp | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/rabitqlib/quantization/rabitq_impl.hpp b/rabitqlib/quantization/rabitq_impl.hpp index af0371d..ff2ac8a 100644 --- a/rabitqlib/quantization/rabitq_impl.hpp +++ b/rabitqlib/quantization/rabitq_impl.hpp @@ -100,6 +100,11 @@ inline void one_bit_code_with_factor( // dot product between centroid and xu_cb T ip_cent_xucb = dot_product(centroid, xu_cb.data(), dim); + // corner case + if (ip_resi_xucb == 0) { + ip_resi_xucb = std::numeric_limits::infinity(); + } + // We use unnormalized vector to get error factor. To be more specific, // sqrt((1 - ^2) / ^2) / sqrt(dim - 1) = 3rd item in following // expression @@ -465,6 +470,11 @@ inline void ex_bits_code_with_factor( T ip_resi_xucb = dot_product(residual_arr.data(), xu_cb.data(), dim); T ip_cent_xucb = dot_product(centroid, xu_cb.data(), dim); + // corner case + if (ip_resi_xucb == 0) { + ip_resi_xucb = std::numeric_limits::infinity(); + } + T tmp_error = l2_norm * kConstEpsilon * std::sqrt( @@ -556,7 +566,8 @@ static inline void rabitq_scalar_impl( float norm_data = std::sqrt(l2norm_sqr(residual_arr.data(), dim)); float norm_quan = std::sqrt(l2norm_sqr(u_cb.data(), dim)); - float cos_similarity = dot_product(residual_arr.data(), u_cb.data(), dim) / (norm_data * norm_quan); + float cos_similarity = + dot_product(residual_arr.data(), u_cb.data(), dim) / (norm_data * norm_quan); if (scalar_quantizer_type == ScalarQuantizerType::RECONSTRUCTION) { delta = norm_data / norm_quan * cos_similarity;