Skip to content

Commit

Permalink
Refactor: Avoid bfloat16x8_t type for MSVC
Browse files Browse the repository at this point in the history
  • Loading branch information
ashvardanian committed Jun 30, 2024
1 parent 812747d commit f28295e
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 31 deletions.
18 changes: 11 additions & 7 deletions include/simsimd/dot.h
Original file line number Diff line number Diff line change
Expand Up @@ -476,27 +476,31 @@ SIMSIMD_PUBLIC void simsimd_dot_bf16_neon(simsimd_bf16_t const* a, simsimd_bf16_
float32x4_t ab_high_vec = vdupq_n_f32(0), ab_low_vec = vdupq_n_f32(0);
simsimd_size_t i = 0;
for (; i + 8 <= n; i += 8) {
bfloat16x8_t a_vec = vld1q_bf16((simsimd_bf16_for_arm_simd_t const*)a + i);
bfloat16x8_t b_vec = vld1q_bf16((simsimd_bf16_for_arm_simd_t const*)b + i);
ab_high_vec = vbfmlaltq_f32(ab_high_vec, a_vec, b_vec);
ab_low_vec = vbfmlalbq_f32(ab_low_vec, a_vec, b_vec);
// MSVC and some other compilers don't support `bfloat16_t` as a native type,
// so we load the values as 16-bit integers and reinterpret them as `bf16` afterwards.
int16x8_t a_vec = vld1q_s16((int16_t const*)a + i);
int16x8_t b_vec = vld1q_s16((int16_t const*)b + i);
ab_high_vec = vbfmlaltq_f32(ab_high_vec, vreinterpretq_bf16_s16(a_vec), vreinterpretq_bf16_s16(b_vec));
ab_low_vec = vbfmlalbq_f32(ab_low_vec, vreinterpretq_bf16_s16(a_vec), vreinterpretq_bf16_s16(b_vec));
}

// In case the software emulation for `bf16` scalars is enabled, the `simsimd_uncompress_bf16`
// function will run. It is extremely slow, so even for the tail, let's combine serial
// loads and stores with vectorized math.
if (i < n) {
union {
bfloat16x8_t bf16_vec;
int16x8_t i16_vec;
simsimd_bf16_t bf16[8];
} a_padded_tail, b_padded_tail;
simsimd_size_t j = 0;
for (; i < n; ++i, ++j)
a_padded_tail.bf16[j] = a[i], b_padded_tail.bf16[j] = b[i];
for (; j < 8; ++j)
a_padded_tail.bf16[j] = 0, b_padded_tail.bf16[j] = 0;
ab_high_vec = vbfmlaltq_f32(ab_high_vec, a_padded_tail.bf16_vec, b_padded_tail.bf16_vec);
ab_low_vec = vbfmlalbq_f32(ab_low_vec, a_padded_tail.bf16_vec, b_padded_tail.bf16_vec);
ab_high_vec = vbfmlaltq_f32(ab_high_vec, vreinterpretq_bf16_s16(a_padded_tail.i16_vec),
vreinterpretq_bf16_s16(b_padded_tail.i16_vec));
ab_low_vec = vbfmlalbq_f32(ab_low_vec, vreinterpretq_bf16_s16(a_padded_tail.i16_vec),
vreinterpretq_bf16_s16(b_padded_tail.i16_vec));
}
*result = vaddvq_f32(ab_high_vec) + vaddvq_f32(ab_low_vec);
}
Expand Down
60 changes: 36 additions & 24 deletions include/simsimd/spatial.h
Original file line number Diff line number Diff line change
Expand Up @@ -341,35 +341,43 @@ SIMSIMD_PUBLIC void simsimd_cos_bf16_neon(simsimd_bf16_t const* a, simsimd_bf16_
float32x4_t b2_high_vec = vdupq_n_f32(0), b2_low_vec = vdupq_n_f32(0);
simsimd_size_t i = 0;
for (; i + 8 <= n; i += 8) {
bfloat16x8_t a_vec = vld1q_bf16((simsimd_bf16_for_arm_simd_t const*)a + i);
bfloat16x8_t b_vec = vld1q_bf16((simsimd_bf16_for_arm_simd_t const*)b + i);
ab_high_vec = vbfmlaltq_f32(ab_high_vec, a_vec, b_vec);
ab_low_vec = vbfmlalbq_f32(ab_low_vec, a_vec, b_vec);
a2_high_vec = vbfmlaltq_f32(a2_high_vec, a_vec, a_vec);
a2_low_vec = vbfmlalbq_f32(a2_low_vec, a_vec, a_vec);
b2_high_vec = vbfmlaltq_f32(b2_high_vec, b_vec, b_vec);
b2_low_vec = vbfmlalbq_f32(b2_low_vec, b_vec, b_vec);
// MSVC and some other compilers don't support `bfloat16_t` as a native type,
// so we load the values as 16-bit integers and reinterpret them as `bf16` afterwards.
int16x8_t a_vec = vld1q_s16((int16_t const*)a + i);
int16x8_t b_vec = vld1q_s16((int16_t const*)b + i);
ab_high_vec = vbfmlaltq_f32(ab_high_vec, vreinterpretq_bf16_s16(a_vec), vreinterpretq_bf16_s16(b_vec));
ab_low_vec = vbfmlalbq_f32(ab_low_vec, vreinterpretq_bf16_s16(a_vec), vreinterpretq_bf16_s16(b_vec));
a2_high_vec = vbfmlaltq_f32(a2_high_vec, vreinterpretq_bf16_s16(a_vec), vreinterpretq_bf16_s16(a_vec));
a2_low_vec = vbfmlalbq_f32(a2_low_vec, vreinterpretq_bf16_s16(a_vec), vreinterpretq_bf16_s16(a_vec));
b2_high_vec = vbfmlaltq_f32(b2_high_vec, vreinterpretq_bf16_s16(b_vec), vreinterpretq_bf16_s16(b_vec));
b2_low_vec = vbfmlalbq_f32(b2_low_vec, vreinterpretq_bf16_s16(b_vec), vreinterpretq_bf16_s16(b_vec));
}

// In case the software emulation for `bf16` scalars is enabled, the `simsimd_uncompress_bf16`
// function will run. It is extremely slow, so even for the tail, let's combine serial
// loads and stores with vectorized math.
if (i < n) {
union {
bfloat16x8_t bf16_vec;
int16x8_t i16_vec;
simsimd_bf16_t bf16[8];
} a_padded_tail, b_padded_tail;
simsimd_size_t j = 0;
for (; i < n; ++i, ++j)
a_padded_tail.bf16[j] = a[i], b_padded_tail.bf16[j] = b[i];
for (; j < 8; ++j)
a_padded_tail.bf16[j] = 0, b_padded_tail.bf16[j] = 0;
ab_high_vec = vbfmlaltq_f32(ab_high_vec, a_padded_tail.bf16_vec, b_padded_tail.bf16_vec);
ab_low_vec = vbfmlalbq_f32(ab_low_vec, a_padded_tail.bf16_vec, b_padded_tail.bf16_vec);
a2_high_vec = vbfmlaltq_f32(a2_high_vec, a_padded_tail.bf16_vec, a_padded_tail.bf16_vec);
a2_low_vec = vbfmlalbq_f32(a2_low_vec, a_padded_tail.bf16_vec, a_padded_tail.bf16_vec);
b2_high_vec = vbfmlaltq_f32(b2_high_vec, b_padded_tail.bf16_vec, b_padded_tail.bf16_vec);
b2_low_vec = vbfmlalbq_f32(b2_low_vec, b_padded_tail.bf16_vec, b_padded_tail.bf16_vec);
ab_high_vec = vbfmlaltq_f32(ab_high_vec, vreinterpretq_bf16_s16(a_padded_tail.i16_vec),
vreinterpretq_bf16_s16(b_padded_tail.i16_vec));
ab_low_vec = vbfmlalbq_f32(ab_low_vec, vreinterpretq_bf16_s16(a_padded_tail.i16_vec),
vreinterpretq_bf16_s16(b_padded_tail.i16_vec));
a2_high_vec = vbfmlaltq_f32(a2_high_vec, vreinterpretq_bf16_s16(a_padded_tail.i16_vec),
vreinterpretq_bf16_s16(a_padded_tail.i16_vec));
a2_low_vec = vbfmlalbq_f32(a2_low_vec, vreinterpretq_bf16_s16(a_padded_tail.i16_vec),
vreinterpretq_bf16_s16(a_padded_tail.i16_vec));
b2_high_vec = vbfmlaltq_f32(b2_high_vec, vreinterpretq_bf16_s16(b_padded_tail.i16_vec),
vreinterpretq_bf16_s16(b_padded_tail.i16_vec));
b2_low_vec = vbfmlalbq_f32(b2_low_vec, vreinterpretq_bf16_s16(b_padded_tail.i16_vec),
vreinterpretq_bf16_s16(b_padded_tail.i16_vec));
}

// Avoid `simsimd_approximate_inverse_square_root` on Arm NEON
Expand All @@ -389,12 +397,16 @@ SIMSIMD_PUBLIC void simsimd_l2sq_bf16_neon(simsimd_bf16_t const* a, simsimd_bf16
float32x4_t sum_high_vec = vdupq_n_f32(0), sum_low_vec = vdupq_n_f32(0);
simsimd_size_t i = 0;
for (; i + 8 <= n; i += 8) {
bfloat16x8_t a_vec = vld1q_bf16((simsimd_bf16_for_arm_simd_t const*)a + i);
bfloat16x8_t b_vec = vld1q_bf16((simsimd_bf16_for_arm_simd_t const*)b + i);
// MSVC and some other compilers don't support `bfloat16_t` as a native type,
// so we load the values as 16-bit integers and reinterpret them as `bf16` afterwards.
int16x8_t a_vec = vld1q_s16((int16_t const*)a + i);
int16x8_t b_vec = vld1q_s16((int16_t const*)b + i);
// We can't perform subtraction in `bf16`. One option would be to upcast to `f32`
// and then subtract, converting back to `bf16` for computing the squared difference.
diff_high_vec = vsubq_f32(vcvt_f32_bf16(vget_high_bf16(a_vec)), vcvt_f32_bf16(vget_high_bf16(b_vec)));
diff_low_vec = vsubq_f32(vcvt_f32_bf16(vget_low_bf16(a_vec)), vcvt_f32_bf16(vget_low_bf16(b_vec)));
diff_high_vec = vsubq_f32(vcvt_f32_bf16(vget_high_bf16(vreinterpretq_bf16_s16(a_vec))),
vcvt_f32_bf16(vget_high_bf16(vreinterpretq_bf16_s16(b_vec))));
diff_low_vec = vsubq_f32(vcvt_f32_bf16(vget_low_bf16(vreinterpretq_bf16_s16(a_vec))),
vcvt_f32_bf16(vget_low_bf16(vreinterpretq_bf16_s16(b_vec))));
sum_high_vec = vfmaq_f32(sum_high_vec, diff_high_vec, diff_high_vec);
sum_low_vec = vfmaq_f32(sum_low_vec, diff_low_vec, diff_low_vec);
}
Expand All @@ -404,18 +416,18 @@ SIMSIMD_PUBLIC void simsimd_l2sq_bf16_neon(simsimd_bf16_t const* a, simsimd_bf16
// loads and stores with vectorized math.
if (i < n) {
union {
bfloat16x8_t bf16_vec;
int16x8_t i16_vec;
simsimd_bf16_t bf16[8];
} a_padded_tail, b_padded_tail;
simsimd_size_t j = 0;
for (; i < n; ++i, ++j)
a_padded_tail.bf16[j] = a[i], b_padded_tail.bf16[j] = b[i];
for (; j < 8; ++j)
a_padded_tail.bf16[j] = 0, b_padded_tail.bf16[j] = 0;
diff_high_vec = vsubq_f32(vcvt_f32_bf16(vget_high_bf16(a_padded_tail.bf16_vec)),
vcvt_f32_bf16(vget_high_bf16(b_padded_tail.bf16_vec)));
diff_low_vec = vsubq_f32(vcvt_f32_bf16(vget_low_bf16(a_padded_tail.bf16_vec)),
vcvt_f32_bf16(vget_low_bf16(b_padded_tail.bf16_vec)));
diff_high_vec = vsubq_f32(vcvt_f32_bf16(vget_high_bf16(vreinterpretq_bf16_s16(a_padded_tail.i16_vec))),
vcvt_f32_bf16(vget_high_bf16(vreinterpretq_bf16_s16(b_padded_tail.i16_vec))));
diff_low_vec = vsubq_f32(vcvt_f32_bf16(vget_low_bf16(vreinterpretq_bf16_s16(a_padded_tail.i16_vec))),
vcvt_f32_bf16(vget_low_bf16(vreinterpretq_bf16_s16(b_padded_tail.i16_vec))));
sum_high_vec = vfmaq_f32(sum_high_vec, diff_high_vec, diff_high_vec);
sum_low_vec = vfmaq_f32(sum_low_vec, diff_low_vec, diff_low_vec);
}
Expand Down

0 comments on commit f28295e

Please sign in to comment.