[Perf] Vectorize more dtype for int4mm (pytorch#126512)
It used to be vectorized only for f16, but there is no reason not to do the same for bf16 or f32.

Spiritual follow-up to pytorch#125290

Pull Request resolved: pytorch#126512
Approved by: https://github.com/Skylion007
malfet authored and ZelboK committed May 19, 2024
1 parent bd10ff6 commit e24f7b3
Showing 1 changed file with 85 additions and 9 deletions.
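For context on the BFloat16 helpers in this diff: bfloat16 is simply the upper 16 bits of an IEEE-754 float32, so widening to f32 is a 16-bit left shift and the store path narrows by truncating the low 16 bits (no rounding), which is exactly what the NEON shifts below vectorize. Here is a minimal scalar sketch of the same round-trip; the helper names are illustrative and not part of the PyTorch sources:

#include <cstdint>
#include <cstring>
#include <cstdio>

// Widen a bf16 bit pattern to f32 by placing it in the high 16 bits.
float bf16_to_f32(uint16_t bits) {
  uint32_t wide = static_cast<uint32_t>(bits) << 16;
  float out;
  std::memcpy(&out, &wide, sizeof(out));
  return out;
}

// Narrow f32 to bf16 by truncation, mirroring store_float32x4(BFloat16*, ...).
uint16_t f32_to_bf16(float val) {
  uint32_t bits;
  std::memcpy(&bits, &val, sizeof(bits));
  return static_cast<uint16_t>(bits >> 16);
}

int main() {
  uint16_t h = f32_to_bf16(1.5f);       // 1.5f is exactly representable in bf16
  std::printf("%f\n", bf16_to_f32(h));  // prints 1.500000
}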
aten/src/ATen/native/cpu/int4mm_kernel.cpp
@@ -341,12 +341,46 @@ inline void tinygemm_kernel(
 
 #if !defined(C10_MOBILE) && defined(__aarch64__)
 #include <arm_neon.h>
-template <int BLOCK_M, int BLOCK_N>
-inline void tinygemm_kernel(
-    const Half* RESTRICT A,
+
+inline float32x4x2_t load_as_float32x4x2(const Half* ptr) {
+  float16x4x2_t f16_val = vld2_f16(reinterpret_cast<const float16_t *>(ptr));
+  auto val_low = vcvt_f32_f16(f16_val.val[0]);
+  auto val_high = vcvt_f32_f16(f16_val.val[1]);
+  return {val_low, val_high};
+}
+
+inline void store_float32x4(Half* ptr, float32x4_t val) {
+  vst1_f16(reinterpret_cast<float16_t*>(ptr), vcvt_f16_f32(val));
+}
+
+inline float32x4x2_t load_as_float32x4x2(const BFloat16* ptr) {
+  int32x4_t shift = vdupq_n_s32(16);
+  uint16x4x2_t u16_val = vld2_u16(reinterpret_cast<const uint16_t *>(ptr));
+  uint32x4_t int_low = vmovl_u16(u16_val.val[0]);
+  uint32x4_t int_high = vmovl_u16(u16_val.val[1]);
+  return {vreinterpretq_f32_u32(vshlq_u32(int_low, shift)), vreinterpretq_f32_u32(vshlq_u32(int_high, shift))};
+}
+
+inline void store_float32x4(BFloat16* ptr, float32x4_t val) {
+  int32x4_t shift = vdupq_n_s32(-16);
+  uint32x4_t uint32_val = vshlq_u32(vreinterpretq_u32_f32(val), shift);
+  vst1_u16(reinterpret_cast<uint16_t*>(ptr), vmovn_u32(uint32_val));
+}
+
+inline float32x4x2_t load_as_float32x4x2(const float* ptr) {
+  return vld2q_f32(ptr);
+}
+
+inline void store_float32x4(float* ptr, float32x4_t val) {
+  vst1q_f32(ptr, val);
+}
+
+template <int BLOCK_M, int BLOCK_N, typename T>
+inline void tinygemm_kernel_(
+    const T* RESTRICT A,
     const uint8_t* RESTRICT B,
-    const Half* RESTRICT ScaleAndZeros,
-    Half* RESTRICT C,
+    const T* RESTRICT ScaleAndZeros,
+    T* RESTRICT C,
     int lda,
     int ldb,
     int ldc,
@@ -368,9 +402,9 @@ inline void tinygemm_kernel(
         if (is_block_start(k, BLOCK_K)) {
           int kb = k / BLOCK_K;
           c10::ForcedUnroll<4>{}([&](auto i) {
-            auto scales_and_zeros = vld2_f16(reinterpret_cast<const float16_t*>(ScaleAndZeros + kb * ldc * 2 + n * 2 + i * 8));
-            scales[i] = vcvt_f32_f16(scales_and_zeros.val[0]);
-            zeros[i] = vcvt_f32_f16(scales_and_zeros.val[1]);
+            auto scales_and_zeros = load_as_float32x4x2(ScaleAndZeros + kb * ldc * 2 + n * 2 + i * 8);
+            scales[i] = scales_and_zeros.val[0];
+            zeros[i] = scales_and_zeros.val[1];
           });
         }
         c10::ForcedUnroll<4>{}([&](auto i) {
@@ -383,11 +417,53 @@ inline void tinygemm_kernel(
         });
       }
       c10::ForcedUnroll<4>{}([&](auto i) {
-        vst1_f16(reinterpret_cast<float16_t*>(C + m * ldc + n + i * 4), vcvt_f16_f32(c_val[i]));
+        store_float32x4(C + m * ldc + n + i * 4, c_val[i]);
       });
     }
   }
 }
+
+template <int BLOCK_M, int BLOCK_N>
+inline void tinygemm_kernel(
+    const Half* RESTRICT A,
+    const uint8_t* RESTRICT B,
+    const Half* RESTRICT ScaleAndZeros,
+    Half* RESTRICT C,
+    int lda,
+    int ldb,
+    int ldc,
+    int K,
+    int BLOCK_K) {
+  tinygemm_kernel_<BLOCK_M, BLOCK_N>(A, B, ScaleAndZeros, C, lda, ldb, ldc, K, BLOCK_K);
+}
+
+template <int BLOCK_M, int BLOCK_N>
+inline void tinygemm_kernel(
+    const BFloat16* RESTRICT A,
+    const uint8_t* RESTRICT B,
+    const BFloat16* RESTRICT ScaleAndZeros,
+    BFloat16* RESTRICT C,
+    int lda,
+    int ldb,
+    int ldc,
+    int K,
+    int BLOCK_K) {
+  tinygemm_kernel_<BLOCK_M, BLOCK_N>(A, B, ScaleAndZeros, C, lda, ldb, ldc, K, BLOCK_K);
+}
+
+template <int BLOCK_M, int BLOCK_N>
+inline void tinygemm_kernel(
+    const float* RESTRICT A,
+    const uint8_t* RESTRICT B,
+    const float* RESTRICT ScaleAndZeros,
+    float* RESTRICT C,
+    int lda,
+    int ldb,
+    int ldc,
+    int K,
+    int BLOCK_K) {
+  tinygemm_kernel_<BLOCK_M, BLOCK_N>(A, B, ScaleAndZeros, C, lda, ldb, ldc, K, BLOCK_K);
+}
 #endif
 
 template<int BLOCK_N>

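Design note on the dispatch above: the three tinygemm_kernel overloads all forward to the single templated tinygemm_kernel_, and plain overload resolution on load_as_float32x4x2 / store_float32x4 picks the per-dtype conversion inside the shared body. Below is a stripped-down, NEON-free sketch of that pattern, assuming a float accumulator; every name in it is hypothetical:

#include <cstdint>
#include <cstring>
#include <cstdio>

// Stand-in for c10::BFloat16: the raw upper-half-of-f32 bit pattern.
struct BFloat16 { uint16_t bits; };

// One load helper per dtype; overload resolution, not branching,
// selects the conversion -- the same trick the commit uses.
float load_as_float(const float* p) { return *p; }
float load_as_float(const BFloat16* p) {
  uint32_t wide = static_cast<uint32_t>(p->bits) << 16;
  float out;
  std::memcpy(&out, &wide, sizeof(out));
  return out;
}

// The kernel body is written once against T; each dtype-specific
// entry point would simply forward here, like tinygemm_kernel_.
template <typename T>
float sum_kernel(const T* a, int n) {
  float acc = 0.0f;
  for (int i = 0; i < n; ++i) {
    acc += load_as_float(a + i);
  }
  return acc;
}

int main() {
  float f[2] = {1.0f, 2.0f};
  BFloat16 b[2] = {{0x3F80}, {0x4000}};  // bf16 bit patterns for 1.0f and 2.0f
  std::printf("%f %f\n", sum_kernel(f, 2), sum_kernel(b, 2));  // 3.000000 3.000000
}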