diff --git a/kernel/arm64/KERNEL.NEOVERSEN2 b/kernel/arm64/KERNEL.NEOVERSEN2 index 5ada34e6bb..6431422faa 100644 --- a/kernel/arm64/KERNEL.NEOVERSEN2 +++ b/kernel/arm64/KERNEL.NEOVERSEN2 @@ -189,7 +189,7 @@ ZGEMMONCOPYOBJ = zgemm_oncopy$(TSUFFIX).$(SUFFIX) ZGEMMOTCOPYOBJ = zgemm_otcopy$(TSUFFIX).$(SUFFIX) ifeq ($(BUILD_BFLOAT16), 1) -BGEMM_BETA = sbgemm_beta_neoversen2.c +BGEMM_BETA = bgemm_beta_neon.c BGEMMKERNEL = sbgemm_kernel_$(BGEMM_UNROLL_M)x$(BGEMM_UNROLL_N)_neoversen2.c BGEMMINCOPY = sbgemm_ncopy_$(BGEMM_UNROLL_M)_neoversen2.c BGEMMITCOPY = sbgemm_tcopy_$(BGEMM_UNROLL_M)_neoversen2.c diff --git a/kernel/arm64/bgemm_kernel_2vlx4_neoversev1_impl.c b/kernel/arm64/bgemm_kernel_2vlx4_neoversev1_impl.c index 215d0d717a..6749c7948f 100644 --- a/kernel/arm64/bgemm_kernel_2vlx4_neoversev1_impl.c +++ b/kernel/arm64/bgemm_kernel_2vlx4_neoversev1_impl.c @@ -40,7 +40,8 @@ #define UPDATE_C(PG, PTR, DST, SRC) \ do { \ - DST = svreinterpret_f32_u32(svld1uh_u32((pghalf), (uint16_t*)PTR)); \ + svtmp16 = svld1_bf16((pghalf), (PTR)); \ + DST = svreinterpret_f32(svzip1_bf16(zeros, svtmp16)); \ DST = svadd_z((PG), SRC, DST); \ svtmp16 = svcvt_bf16_f32_z((PG), DST); \ svtmp16 = svuzp1_bf16(svtmp16, svtmp16); \ @@ -55,7 +56,8 @@ #define UPDATE_C(PG, PTR, DST, SRC) \ do { \ - DST = svreinterpret_f32_u32(svld1uh_u32((pghalf), (uint16_t*)PTR)); \ + svtmp16 = svld1_bf16((pghalf), (PTR)); \ + DST = svreinterpret_f32(svzip1_bf16(zeros, svtmp16)); \ DST = svmad_z((PG), svalpha, SRC, DST); \ svtmp16 = svcvt_bf16_f32_z((PG), DST); \ svtmp16 = svuzp1_bf16(svtmp16, svtmp16); \ @@ -133,6 +135,7 @@ static int bgemm_kernel_neoversev1_alpha(BLASLONG m, BLASLONG n, BLASLONG k, OUTPUT_FLOAT *ptr_c0, *ptr_c1, *ptr_c2, *ptr_c3; svfloat32_t tmp0, tmp1, tmp2, tmp3; #ifdef BGEMM + svbfloat16_t zeros = svdup_n_bf16(TO16(0.0)); svbfloat16_t svtmp16; #else float32x2_t tmp4, tmp5, tmp6, tmp7; diff --git a/kernel/arm64/sbgemm_kernel_8x4_neoversen2_impl.c b/kernel/arm64/sbgemm_kernel_8x4_neoversen2_impl.c index 61889ca7a9..d4e0a38afe 100644 --- a/kernel/arm64/sbgemm_kernel_8x4_neoversen2_impl.c +++ b/kernel/arm64/sbgemm_kernel_8x4_neoversen2_impl.c @@ -51,7 +51,8 @@ #ifdef ALPHA_ONE #define UPDATE_C(PG16, PG32, PTR, SRC) \ do { \ - tmp32 = svreinterpret_f32_u32(svld1uh_u32((PG16), (uint16_t*)PTR)); \ + tmp16 = svld1_bf16((PG16), (PTR)); \ + tmp32 = svreinterpret_f32(svzip1_bf16(zeros, tmp16)); \ tmp32 = svadd_z((PG32), SRC, tmp32); \ tmp16 = svcvt_bf16_f32_z((PG32), tmp32); \ tmp16 = svuzp1_bf16(tmp16, tmp16); \ @@ -60,7 +61,8 @@ #else #define UPDATE_C(PG16, PG32, PTR, SRC) \ do { \ - tmp32 = svreinterpret_f32_u32(svld1uh_u32((PG16), (uint16_t*)PTR)); \ + tmp16 = svld1_bf16((PG16), (PTR)); \ + tmp32 = svreinterpret_f32(svzip1_bf16(zeros, tmp16)); \ tmp32 = svmad_z((PG32), svalpha, SRC, tmp32); \ tmp16 = svcvt_bf16_f32_z((PG32), tmp32); \ tmp16 = svuzp1_bf16(tmp16, tmp16); \ @@ -121,6 +123,7 @@ static int gemm_kernel_neoversen2_alpha(BLASLONG m, BLASLONG n, BLASLONG k, FLOA #ifdef BGEMM svbool_t pg16_first_2 = svdupq_b16(1, 1, 0, 0, 0, 0, 0, 0); svbool_t pg16_first_1 = svdupq_b16(1, 0, 0, 0, 0, 0, 0, 0); + svbfloat16_t zeros = svdup_n_bf16(vcvth_bf16_f32(0.0)); #endif bfloat16_t *ptr_a = (bfloat16_t *)A; diff --git a/kernel/generic/gemv_t.c b/kernel/generic/gemv_t.c index 3b651b5c1b..ecf8ebbad5 100644 --- a/kernel/generic/gemv_t.c +++ b/kernel/generic/gemv_t.c @@ -52,7 +52,14 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT * temp += BF16TOF32(a_ptr[i]) * BF16TOF32(x[ix]); ix += inc_x; } - y[iy] += F32TOBF16(ALPHA * temp); + if (BETA == ZERO) + { + y[iy] = F32TOBF16(ALPHA * temp); + } + else + { + y[iy] = F32TOBF16(ALPHA * temp + BETA * BF16TOF32(y[iy])); + } iy += inc_y; a_ptr += lda; } diff --git a/test/compare_sgemm_bgemm.c b/test/compare_sgemm_bgemm.c index be7c538b60..1fe4501e37 100644 --- a/test/compare_sgemm_bgemm.c +++ b/test/compare_sgemm_bgemm.c @@ -44,7 +44,7 @@ main (int argc, char *argv[]) int ret = 0; int loop = BGEMM_LARGEST; char transA = 'N', transB = 'N'; - float alpha = 1.0, beta = 0.0; + float alpha = 1.0, beta = 1.0; bfloat16 alpha_bf16; sbstobf16_(&one, &alpha, &one, &alpha_bf16, &one); bfloat16 beta_bf16; @@ -94,9 +94,15 @@ main (int argc, char *argv[]) transB = 'T'; } - memset(CC, 0, m * n * sizeof(bfloat16)); - memset(DD, 0, m * n * sizeof(FLOAT)); - memset(C, 0, m * n * sizeof(FLOAT)); + for (j = 0; j < m; j++) + { + for (i = 0; i < n; i++) + { + C[j * n + i] = 100.0; + DD[j * n + i] = 100.0; + sbstobf16_(&one, &C[j * n + i], &one, &CC[j * n + i], &one); + } + } SGEMM (&transA, &transB, &m, &n, &k, &alpha, A, &m, B, &k, &beta, C, &m); @@ -152,7 +158,8 @@ main (int argc, char *argv[]) } if (ret != 0) { - fprintf (stderr, "FATAL ERROR BGEMM - Return code: %d\n", ret); + fprintf(stderr, "BGEMM FAILURES: %d\n", ret); + return 1; } return ret; diff --git a/test/compare_sgemm_sbgemm.c b/test/compare_sgemm_sbgemm.c index 4892225168..e7a145f2d6 100644 --- a/test/compare_sgemm_sbgemm.c +++ b/test/compare_sgemm_sbgemm.c @@ -140,7 +140,8 @@ main (int argc, char *argv[]) } if (ret != 0) { - fprintf (stderr, "FATAL ERROR SBGEMM - Return code: %d\n", ret); + fprintf(stderr, "SBGEMM FAILURES: %d\n", ret); + return 1; } return ret; diff --git a/test/compare_sgemv_bgemv.c b/test/compare_sgemv_bgemv.c index 014c7da50d..d9dc30d9a6 100644 --- a/test/compare_sgemv_bgemv.c +++ b/test/compare_sgemv_bgemv.c @@ -147,7 +147,10 @@ int main(int argc, char *argv[]) } // alpha } // beta - if (ret != 0) - fprintf(stderr, "FATAL ERROR BGEMV - Return code: %d\n", ret); + if (ret != 0) { + fprintf(stderr, "BGEMV FAILURES: %d\n", ret); + return 1; + } + return ret; } diff --git a/test/compare_sgemv_sbgemv.c b/test/compare_sgemv_sbgemv.c index 15cdce6cb5..627cf7146c 100644 --- a/test/compare_sgemv_sbgemv.c +++ b/test/compare_sgemv_sbgemv.c @@ -122,7 +122,10 @@ main (int argc, char *argv[]) } // alpha } // beta - if (ret != 0) - fprintf (stderr, "FATAL ERROR SBGEMV - Return code: %d\n", ret); + if (ret != 0) { + fprintf(stderr, "SBGEMV FAILURES: %d\n", ret); + return 1; + } + return ret; }