From 6fcaf3c0665dcd2e34486ad6199de58e13080045 Mon Sep 17 00:00:00 2001
From: Nikita Shulga
Date: Wed, 1 May 2024 14:04:46 +0000
Subject: [PATCH] Add NEON-accelerated int8mm for bfloat16 (#125290)

Apparently `vshlq_u32` is faster than `vcvt_f32_f16`.

Refactor the NEON `tinygemm_kernel` to rely on `load_as_float32x4` and
`load_as_float32x4x2`, and implement them for float16 (using `vcvt`),
bfloat16 (using a left shift), and plain float32 (no conversion needed).

As a result, stories110M runs at 60 tokens/sec with f16, but at 66
tokens/sec with bf16 and 75 tokens/sec with f32, though higher bandwidth
demand starts to favor the reduced-precision floating types as model
size gets bigger.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125290
Approved by: https://github.com/mikekgfb
---
 aten/src/ATen/native/cpu/int8mm_kernel.cpp | 94 +++++++++++++++++++---
 1 file changed, 82 insertions(+), 12 deletions(-)

diff --git a/aten/src/ATen/native/cpu/int8mm_kernel.cpp b/aten/src/ATen/native/cpu/int8mm_kernel.cpp
index 4ef6cde4a8799..bd266030b2566 100644
--- a/aten/src/ATen/native/cpu/int8mm_kernel.cpp
+++ b/aten/src/ATen/native/cpu/int8mm_kernel.cpp
@@ -185,17 +185,50 @@ inline void tinygemm_kernel(
 #if !defined(C10_MOBILE) && defined(__aarch64__)
 #include <arm_neon.h>
 
-static inline float reduce(float32x4_t x) {
+inline float reduce(float32x4_t x) {
   auto sum = vpaddq_f32(x, x);
   return vgetq_lane_f32(vpaddq_f32(sum, sum), 0);
 }
 
-template <int BLOCK_M, int BLOCK_N>
-inline void tinygemm_kernel(
-    const Half* RESTRICT A,
+inline float32x4x2_t load_as_float32x4x2(const Half* ptr) {
+  float16x8_t f16_val = vld1q_f16(reinterpret_cast<const float16_t*>(ptr));
+  auto val_low = vcvt_f32_f16(vget_low_f16(f16_val));
+  auto val_high = vcvt_f32_f16(vget_high_f16(f16_val));
+  return {val_low, val_high};
+}
+
+inline float32x4_t load_as_float32x4(const Half* ptr) {
+  return vcvt_f32_f16(vld1_f16(reinterpret_cast<const float16_t*>(ptr)));
+}
+
+inline float32x4x2_t load_as_float32x4x2(const BFloat16* ptr) {
+  int32x4_t shift = vdupq_n_s32(16);
+  uint16x8_t u16_val = vld1q_u16(reinterpret_cast<const uint16_t*>(ptr));
+  uint32x4_t int_low = vmovl_u16(vget_low_u16(u16_val));
+  uint32x4_t int_high = vmovl_u16(vget_high_u16(u16_val));
+  return {vreinterpretq_f32_u32(vshlq_u32(int_low, shift)), vreinterpretq_f32_u32(vshlq_u32(int_high, shift))};
+}
+
+inline float32x4_t load_as_float32x4(const BFloat16* ptr) {
+  int32x4_t shift = vdupq_n_s32(16);
+  uint32x4_t as_int = vmovl_u16(vld1_u16(reinterpret_cast<const uint16_t*>(ptr)));
+  return vreinterpretq_f32_u32(vshlq_u32(as_int, shift));
+}
+
+inline float32x4_t load_as_float32x4(const float* ptr) {
+  return vld1q_f32(ptr);
+}
+
+inline float32x4x2_t load_as_float32x4x2(const float* ptr) {
+  return {vld1q_f32(ptr), vld1q_f32(ptr + 4)};
+}
+
+template <int BLOCK_M, int BLOCK_N, typename T>
+inline void tinygemm_kernel_(
+    const T* RESTRICT A,
     const int8_t* RESTRICT B,
-    const Half* RESTRICT scales,
-    Half* RESTRICT C,
+    const T* RESTRICT scales,
+    T* RESTRICT C,
     int lda,
     int ldb,
     int ldc,
@@ -207,24 +240,61 @@ inline void tinygemm_kernel(
       c_val[i] = vdupq_n_f32(0.0);
     });
     for (int k = 0; k < K; k += 8) {
-      float16x8_t a_val = vld1q_f16(reinterpret_cast<const float16_t*>(A) + m * lda + k);
-      auto a_val_low = vcvt_f32_f16(vget_low_f16(a_val));
-      auto a_val_high = vcvt_f32_f16(vget_high_f16(a_val));
+      auto a_val = load_as_float32x4x2(A + m * lda + k);
       c10::ForcedUnroll<BLOCK_N>{}([&](auto i) {
         int16x8_t b_val = vmovl_s8(vld1_s8(B + i * ldb + k));
         auto b_val_low = vcvtq_f32_s32(vmovl_s16(vget_low_s16(b_val)));
         auto b_val_high = vcvtq_f32_s32(vmovl_s16(vget_high_s16(b_val)));
-        c_val[i] = vfmaq_f32(c_val[i], a_val_high, b_val_high);
-        c_val[i] = vfmaq_f32(c_val[i], a_val_low, b_val_low);
+        c_val[i] = vfmaq_f32(c_val[i], a_val.val[1], b_val_high);
+        c_val[i] = vfmaq_f32(c_val[i], a_val.val[0], b_val_low);
       });
     }
-    float32x4_t scale_val = vcvt_f32_f16(vld1_f16(reinterpret_cast<const float16_t*>(scales)));
+    float32x4_t scale_val = load_as_float32x4(scales);
     c10::ForcedUnroll<BLOCK_N>{}([&](auto i) {
       C[m * ldc + i] = reduce(c_val[i]) * vgetq_lane_f32(scale_val, i);
     });
   }
 }
+
+template <int BLOCK_M, int BLOCK_N>
+inline void tinygemm_kernel(
+    const Half* RESTRICT A,
+    const int8_t* RESTRICT B,
+    const Half* RESTRICT scales,
+    Half* RESTRICT C,
+    int lda,
+    int ldb,
+    int ldc,
+    int K) {
+  tinygemm_kernel_<BLOCK_M, BLOCK_N>(A, B, scales, C, lda, ldb, ldc, K);
+}
+
+template <int BLOCK_M, int BLOCK_N>
+inline void tinygemm_kernel(
+    const BFloat16* RESTRICT A,
+    const int8_t* RESTRICT B,
+    const BFloat16* RESTRICT scales,
+    BFloat16* RESTRICT C,
+    int lda,
+    int ldb,
+    int ldc,
+    int K) {
+  tinygemm_kernel_<BLOCK_M, BLOCK_N>(A, B, scales, C, lda, ldb, ldc, K);
+}
+
+template <int BLOCK_M, int BLOCK_N>
+inline void tinygemm_kernel(
+    const float* RESTRICT A,
+    const int8_t* RESTRICT B,
+    const float* RESTRICT scales,
+    float* RESTRICT C,
+    int lda,
+    int ldb,
+    int ldc,
+    int K) {
+  tinygemm_kernel_<BLOCK_M, BLOCK_N>(A, B, scales, C, lda, ldb, ldc, K);
+}
 
 #endif
 
 // non-vectorized version
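
A note on why the bfloat16 loads above can use a plain shift instead of a conversion instruction: bfloat16 stores the upper 16 bits of an IEEE-754 float32, so widening the 16-bit pattern and shifting it left by 16 reproduces the original float32 bit pattern exactly, with no rounding and no `vcvt`. The following standalone scalar sketch (not part of the patch; the helper name is illustrative) mirrors what `vmovl_u16` + `vshlq_u32` + `vreinterpretq_f32_u32` do per lane.

#include <cstdint>
#include <cstdio>
#include <cstring>

// Scalar analogue of the NEON bfloat16 path: widen the 16 stored bits into
// the upper half of a 32-bit word, then reinterpret that word as float32.
static float bf16_bits_to_float(uint16_t bits) {
  uint32_t widened = static_cast<uint32_t>(bits) << 16;  // vmovl_u16 + vshlq_u32 by 16
  float out;
  std::memcpy(&out, &widened, sizeof(out));              // vreinterpretq_f32_u32
  return out;
}

int main() {
  // 0x3F80 is the bfloat16 pattern for 1.0f (upper half of 0x3F800000).
  std::printf("%f\n", bf16_bits_to_float(0x3F80));  // prints 1.000000
  // 0xC0A0 is the bfloat16 pattern for -5.0f (upper half of 0xC0A00000).
  std::printf("%f\n", bf16_bits_to_float(0xC0A0));  // prints -5.000000
  return 0;
}

The float16 loads cannot take this shortcut because float16 has a different exponent width than float32, which is why that overload still goes through `vcvt_f32_f16`.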
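For reference, this is my reading of the math the refactored kernel computes per BLOCK_M x BLOCK_N tile, as a plain scalar sketch (illustrative only, not the ATen API; the function name and plain C++ types are assumptions): each output element is a K-length dot product between a floating-point row of A and an int8 row of B, scaled per output column. The NEON version does the same accumulation four lanes at a time and finishes with the horizontal `reduce`.

#include <cstdint>

// Scalar reference: C[m][n] = scales[n] * sum_k A[m][k] * B[n][k],
// with A, scales, C in floating point (T) and B stored as int8.
template <typename T>
void tinygemm_reference(
    const T* A, const int8_t* B, const T* scales, T* C,
    int BLOCK_M, int BLOCK_N, int lda, int ldb, int ldc, int K) {
  for (int m = 0; m < BLOCK_M; ++m) {
    for (int n = 0; n < BLOCK_N; ++n) {
      float acc = 0.0f;  // accumulate in float32, as the NEON kernel does
      for (int k = 0; k < K; ++k) {
        acc += static_cast<float>(A[m * lda + k]) * static_cast<float>(B[n * ldb + k]);
      }
      C[m * ldc + n] = static_cast<T>(acc * static_cast<float>(scales[n]));
    }
  }
}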