Dorer SQ8 dist functions [MOD-9626] #673

Merged
54 commits merged on Jun 9, 2025

Collaborator

@dor-forer dor-forer commented May 12, 2025

Describe the changes in the pull request

Add dist functions and tests to support SQ8

Which issues this PR fixes

  1. MOD-9626

Main objects this PR modified

  1. dist functions

Mark if applicable

  • This PR introduces API changes
  • This PR introduces serialization changes

@dor-forer dor-forer requested review from lerman25 and Copilot May 22, 2025 07:22
Contributor

@Copilot Copilot AI left a comment

Pull Request Overview

This PR adds support for SQ8 (8-bit quantized) distance functions for L2, inner product, and cosine similarities across various architectures, enhancing SIMD-based performance.

  • Introduces SQ8_L2Sqr and its SIMD implementations (SVE, SSE4, NEON, AVX2, AVX512, AVX2_FMA)
  • Adds SQ8_InnerProduct and SQ8_Cosine with matching SIMD variants and a common dispatcher in IP_space.cpp
  • Updates headers (L2.h, IP.h), build scripts (CMakeLists, instruction flags) for new functions and compiler flags

Reviewed Changes

Copilot reviewed 44 out of 44 changed files in this pull request and generated no comments.

| File | Description |
| --- | --- |
| src/VecSim/spaces/L2/L2_SVE_SQ8.h | Adds SVE-based SQ8 L2 squared distance step and template |
| src/VecSim/spaces/L2/L2_SSE4_SQ8.h | Adds SSE4-based SQ8 L2 squared distance |
| src/VecSim/spaces/L2/L2_NEON_SQ8.h | Adds NEON-based SQ8 L2 squared distance |
| src/VecSim/spaces/L2/L2_AVX512F_BW_VL_VNNI_SQ8.h | Adds AVX512F BW VL VNNI SQ8 L2 squared distance |
| src/VecSim/spaces/L2/L2_AVX2_SQ8.h | Adds AVX2-based SQ8 L2 squared distance |
| src/VecSim/spaces/L2/L2_AVX2_FMA_SQ8.h | Adds AVX2+FMA SQ8 L2 squared distance |
| src/VecSim/spaces/L2/L2.h | Declares SQ8_L2Sqr |
| src/VecSim/spaces/L2/L2.cpp | Implements naive SQ8_L2Sqr |
| src/VecSim/spaces/IP_space.cpp | Registers SQ8 inner-product and cosine dispatch functions |
| src/VecSim/spaces/IP/IP_SVE_SQ8.h | Adds SVE-based SQ8 inner-product and cosine |
| src/VecSim/spaces/IP/IP_SSE4_SQ8.h | Adds SSE4-based SQ8 inner-product and cosine |
| ... | (other SIMD variants for IP and Cosine) |
| src/VecSim/spaces/IP/IP.h | Declares SQ8_InnerProduct, SQ8_Cosine |
| src/VecSim/spaces/CMakeLists.txt | Enables SSE4 and AVX2+FMA source files with proper flags |
| cmake/x86_64InstructionFlags.cmake | Adds detection and definitions for SSE4 and AVX2_FMA flags |

Comments suppressed due to low confidence (6)

src/VecSim/spaces/IP/IP_SVE_SQ8.h:45

  • [nitpick] Avoid naming a variable min which can conflict with std::min; consider renaming to min_val for clarity and to prevent shadowing.
    float min = *(float *)(pVect2 + dimension);

src/VecSim/spaces/L2/L2.cpp:22

  • Typo in comment: 'structred' should be 'structured'.
    // it structred as [quantized values (uint8_t * dim)][min_val (float)][delta

src/VecSim/spaces/L2/L2.h:14

  • [nitpick] No unit tests were added for SQ8_L2Sqr; please add tests covering full-chunk, partial-chunk, and edge-case dimensions to ensure correctness.
float SQ8_L2Sqr(const void *pVect1v, const void *pVect2v, size_t dimension);

src/VecSim/spaces/L2/L2.cpp:13

  • Remove the unused #include <iostream> from this file to avoid unnecessary dependencies.
#include <iostream>

src/VecSim/spaces/IP/IP_SVE_SQ8.h:11

  • The <iostream> header is not used in this implementation, consider removing it to keep the header lightweight.
#include <iostream>

src/VecSim/spaces/L2/L2.h:14

  • Consider adding a corresponding L2_SQ8_GetDistFunc and registering it in the L2 space selector (similar to IP_space) so that the SQ8 L2 implementation can be chosen dynamically.
float SQ8_L2Sqr(const void *pVect1v, const void *pVect2v, size_t dimension);
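For readers following the review, the computation under discussion can be sketched as a scalar reference. This is not the code merged in the PR: the function name `sq8_l2sqr_sketch` is hypothetical, and the layout (quantized bytes followed by a float `min_val` and a float `delta`) is assumed from the code comment and the offsets quoted elsewhere in this conversation. The first vector is taken to be plain float32 and the second to be the SQ8 blob.

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>

// Hypothetical scalar reference, not the implementation from this PR.
// Assumed layout of pVect2v: [uint8_t codes * dimension][float min_val][float delta]
float sq8_l2sqr_sketch(const void *pVect1v, const void *pVect2v, size_t dimension) {
    const float *v1 = static_cast<const float *>(pVect1v);
    const uint8_t *v2 = static_cast<const uint8_t *>(pVect2v);

    // The metadata follows the codes, so it may sit at an unaligned address;
    // memcpy reads it without relying on an aligned float load.
    float min_val, delta;
    std::memcpy(&min_val, v2 + dimension, sizeof(float));
    std::memcpy(&delta, v2 + dimension + sizeof(float), sizeof(float));

    float sum = 0.0f;
    for (size_t i = 0; i < dimension; ++i) {
        float dequant = v2[i] * delta + min_val; // dequantize: (val * delta) + min_val
        float diff = v1[i] - dequant;
        sum += diff * diff;
    }
    return sum;
}
```

A scalar version like this is mainly useful as a baseline to compare the SIMD variants against in tests.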

Collaborator

@lerman25 lerman25 left a comment

Nice L2
Some comments

sum = svmla_f32_x(pg, sum, v1, v2_dequant);

// Move to the next set of elements
offset += svcntw();
Collaborator

In "regular" IP/L2 SVE we passed the offset to the function as parameter, I think we should align this behavior here

Collaborator Author

It is passed as a parameter.
Or am I missing something?

Collaborator

I meant replacing offset += svcntw(); with offset += chunk;
See IP_SVE_INT8 for reference.

Collaborator Author

Oh I see now.
You all did it behind my back :) I didn't do it on my FP32 implementation.
I will align with your implementation.

@dor-forer dor-forer requested a review from lerman25 May 27, 2025 06:17
lerman25 previously approved these changes May 27, 2025
Collaborator

@lerman25 lerman25 left a comment

Great job 👑
Other than my 2 comments nothing else to add

Comment on lines 34 to 37
__m256 diff_squared = _mm256_mul_ps(diff, diff);

// Add to running sum
sum256 = _mm256_add_ps(sum256, diff_squared);
Collaborator

fmadd is an option
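To illustrate the suggestion: in the AVX2+FMA variant, the quoted multiply and add would presumably collapse into a single `sum256 = _mm256_fmadd_ps(diff, diff, sum256);`. The portable sketch below shows the same idea in scalar form using `std::fma`. The function name `l2sqr_fma_sketch` is hypothetical, and this is only an analogue of the SIMD code, not what the PR contains.

```cpp
#include <cmath>
#include <cstddef>

// Hypothetical scalar analogue of the fmadd suggestion, not code from this PR.
// Each iteration computes sum = diff * diff + sum as one fused operation,
// replacing a separate multiply followed by an add.
float l2sqr_fma_sketch(const float *a, const float *b, size_t dim) {
    float sum = 0.0f;
    for (size_t i = 0; i < dim; ++i) {
        float diff = a[i] - b[i];
        sum = std::fma(diff, diff, sum);
    }
    return sum;
}
```

Besides saving an instruction per step, the fused form rounds once instead of twice, so results can differ slightly from the mul-then-add version; test tolerances may need to account for that.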

params[2] = inv_norm;

float dist = SQ8_L2Sqr((const void *)v1_orig, (const void *)v2_compressed.data(), dim);
ASSERT_NEAR(dist, 0.0f, 0.01f) << "SQ8_Cosine failed to match expected distance";
Collaborator

is 0.01 enough?
@GuyAv46 WDYT?
Other tests used 0.000001

Collaborator Author

@dor-forer dor-forer May 27, 2025

The IP tests pass with 0.000001 and the L2 tests pass with 0.00001

Comment on lines +33 to +34
const float min_val = *reinterpret_cast<const float *>(pVect2 + dimension);
const float delta = *reinterpret_cast<const float *>(pVect2 + dimension + sizeof(float));
Collaborator

consider remodelling so the metadata is at the start of the vector
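One way to read this suggestion is a header-first layout, where `min_val` and `delta` precede the quantized bytes so their offset no longer depends on `dimension`. The sketch below is only an assumption about what such a layout could look like; `SQ8Header` and `sq8_dequant_at` are hypothetical names and do not appear in the PR.

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>

// Hypothetical header-first layout, not the layout used in this PR:
// [float min_val][float delta][uint8_t codes * dimension]
struct SQ8Header {
    float min_val;
    float delta;
};

// Dequantize element i of a blob stored in the header-first layout.
float sq8_dequant_at(const void *blob, size_t i) {
    SQ8Header header;
    std::memcpy(&header, blob, sizeof(header)); // metadata sits at a fixed offset of 0
    const uint8_t *codes = static_cast<const uint8_t *>(blob) + sizeof(SQ8Header);
    return codes[i] * header.delta + header.min_val;
}
```

With the metadata at offset 0, a blob allocated at a float-aligned address keeps both floats aligned regardless of the dimension, which is not guaranteed when they trail the codes.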

Comment on lines 139 to 140
if (dim % 16 == 0) // no point in aligning if we have an offsetting residual
*alignment = 16 * sizeof(float); // handles 16 floats
Collaborator

If we are not aligned once the vector's metadata is included, there is no need to request alignment. If we're not sure what the alignment should be, it is better not to set it (keep it 0).

float32x4_t v2_f = vcvtq_f32_u32(v2_u32);

// Dequantize: (val * delta) + min_val
float32x4_t v2_dequant = vmlaq_f32(min_val_vec, v2_f, delta_vec);
Collaborator

There are intrinsics for MLA with a scalar, removing the need to create min_val_vec and delta_vec. We should check (in the following PR) what the performance difference is.

Comment on lines +80 to +94
if constexpr (final_residual >= 1) {
v1 = vld1q_lane_f32(pVect1, v1, 0);
float dequant0 = pVect2[0] * delta + min_val;
v2_dequant = vld1q_lane_f32(&dequant0, v2_dequant, 0);
}
if constexpr (final_residual >= 2) {
v1 = vld1q_lane_f32(pVect1 + 1, v1, 1);
float dequant1 = pVect2[1] * delta + min_val;
v2_dequant = vld1q_lane_f32(&dequant1, v2_dequant, 1);
}
if constexpr (final_residual >= 3) {
v1 = vld1q_lane_f32(pVect1 + 2, v1, 2);
float dequant2 = pVect2[2] * delta + min_val;
v2_dequant = vld1q_lane_f32(&dequant2, v2_dequant, 2);
}
Collaborator

should be evaluated (performance)

Comment on lines 63 to 68
float dequant0 = quantized[0] * delta + min;
v2_dequant = _mm_load_ss(&dequant0);

// Dequantize next two values
float dequant_high[2] = {quantized[1] * delta + min, quantized[2] * delta + min};
v2_dequant = _mm_loadh_pi(v2_dequant, (__m64 *)dequant_high);
Collaborator

Consider using set and manually setting the relevant elements of the vector, instead of loading from a stack variable

@dor-forer dor-forer requested a review from GuyAv46 June 5, 2025 15:09
Comment on lines 73 to 74
if (dim % 16 == 0) // no point in aligning if we have an offsetting residual
*alignment = 16 * sizeof(float); // handles 16 floats
Collaborator

remove redundant alignment

@dor-forer dor-forer requested review from GuyAv46 and lerman25 June 8, 2025 16:36
Collaborator

@GuyAv46 GuyAv46 left a comment

We should still validate the implementation performance and address the performance comments left on this PR

@dor-forer dor-forer added this pull request to the merge queue Jun 9, 2025
Merged via the queue into main with commit 6a84603 Jun 9, 2025
23 checks passed
@dor-forer dor-forer deleted the dorer-sq8-dist-functions branch June 9, 2025 10:53