Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion kernel/arm64/KERNEL.NEOVERSEN2
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ ZGEMMONCOPYOBJ = zgemm_oncopy$(TSUFFIX).$(SUFFIX)
ZGEMMOTCOPYOBJ = zgemm_otcopy$(TSUFFIX).$(SUFFIX)

ifeq ($(BUILD_BFLOAT16), 1)
BGEMM_BETA = sbgemm_beta_neoversen2.c
BGEMM_BETA = bgemm_beta_neon.c
BGEMMKERNEL = sbgemm_kernel_$(BGEMM_UNROLL_M)x$(BGEMM_UNROLL_N)_neoversen2.c
BGEMMINCOPY = sbgemm_ncopy_$(BGEMM_UNROLL_M)_neoversen2.c
BGEMMITCOPY = sbgemm_tcopy_$(BGEMM_UNROLL_M)_neoversen2.c
Expand Down
7 changes: 5 additions & 2 deletions kernel/arm64/bgemm_kernel_2vlx4_neoversev1_impl.c
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@

#define UPDATE_C(PG, PTR, DST, SRC) \
do { \
DST = svreinterpret_f32_u32(svld1uh_u32((pghalf), (uint16_t*)PTR)); \
svtmp16 = svld1_bf16((pghalf), (PTR)); \
DST = svreinterpret_f32(svzip1_bf16(zeros, svtmp16)); \
DST = svadd_z((PG), SRC, DST); \
svtmp16 = svcvt_bf16_f32_z((PG), DST); \
svtmp16 = svuzp1_bf16(svtmp16, svtmp16); \
Expand All @@ -55,7 +56,8 @@

#define UPDATE_C(PG, PTR, DST, SRC) \
do { \
DST = svreinterpret_f32_u32(svld1uh_u32((pghalf), (uint16_t*)PTR)); \
svtmp16 = svld1_bf16((pghalf), (PTR)); \
DST = svreinterpret_f32(svzip1_bf16(zeros, svtmp16)); \
DST = svmad_z((PG), svalpha, SRC, DST); \
svtmp16 = svcvt_bf16_f32_z((PG), DST); \
svtmp16 = svuzp1_bf16(svtmp16, svtmp16); \
Expand Down Expand Up @@ -133,6 +135,7 @@ static int bgemm_kernel_neoversev1_alpha(BLASLONG m, BLASLONG n, BLASLONG k,
OUTPUT_FLOAT *ptr_c0, *ptr_c1, *ptr_c2, *ptr_c3;
svfloat32_t tmp0, tmp1, tmp2, tmp3;
#ifdef BGEMM
svbfloat16_t zeros = svdup_n_bf16(TO16(0.0));
svbfloat16_t svtmp16;
#else
float32x2_t tmp4, tmp5, tmp6, tmp7;
Expand Down
7 changes: 5 additions & 2 deletions kernel/arm64/sbgemm_kernel_8x4_neoversen2_impl.c
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@
#ifdef ALPHA_ONE
#define UPDATE_C(PG16, PG32, PTR, SRC) \
do { \
tmp32 = svreinterpret_f32_u32(svld1uh_u32((PG16), (uint16_t*)PTR)); \
tmp16 = svld1_bf16((PG16), (PTR)); \
tmp32 = svreinterpret_f32(svzip1_bf16(zeros, tmp16)); \
tmp32 = svadd_z((PG32), SRC, tmp32); \
tmp16 = svcvt_bf16_f32_z((PG32), tmp32); \
tmp16 = svuzp1_bf16(tmp16, tmp16); \
Expand All @@ -60,7 +61,8 @@
#else
#define UPDATE_C(PG16, PG32, PTR, SRC) \
do { \
tmp32 = svreinterpret_f32_u32(svld1uh_u32((PG16), (uint16_t*)PTR)); \
tmp16 = svld1_bf16((PG16), (PTR)); \
tmp32 = svreinterpret_f32(svzip1_bf16(zeros, tmp16)); \
tmp32 = svmad_z((PG32), svalpha, SRC, tmp32); \
tmp16 = svcvt_bf16_f32_z((PG32), tmp32); \
tmp16 = svuzp1_bf16(tmp16, tmp16); \
Expand Down Expand Up @@ -121,6 +123,7 @@ static int gemm_kernel_neoversen2_alpha(BLASLONG m, BLASLONG n, BLASLONG k, FLOA
#ifdef BGEMM
svbool_t pg16_first_2 = svdupq_b16(1, 1, 0, 0, 0, 0, 0, 0);
svbool_t pg16_first_1 = svdupq_b16(1, 0, 0, 0, 0, 0, 0, 0);
svbfloat16_t zeros = svdup_n_bf16(vcvth_bf16_f32(0.0));
#endif

bfloat16_t *ptr_a = (bfloat16_t *)A;
Expand Down
9 changes: 8 additions & 1 deletion kernel/generic/gemv_t.c
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,14 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT *
temp += BF16TOF32(a_ptr[i]) * BF16TOF32(x[ix]);
ix += inc_x;
}
y[iy] += F32TOBF16(ALPHA * temp);
if (BETA == ZERO)
{
y[iy] = F32TOBF16(ALPHA * temp);
}
else
{
y[iy] = F32TOBF16(ALPHA * temp + BETA * BF16TOF32(y[iy]));
}
iy += inc_y;
a_ptr += lda;
}
Expand Down
17 changes: 12 additions & 5 deletions test/compare_sgemm_bgemm.c
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ main (int argc, char *argv[])
int ret = 0;
int loop = BGEMM_LARGEST;
char transA = 'N', transB = 'N';
float alpha = 1.0, beta = 0.0;
float alpha = 1.0, beta = 1.0;
bfloat16 alpha_bf16;
sbstobf16_(&one, &alpha, &one, &alpha_bf16, &one);
bfloat16 beta_bf16;
Expand Down Expand Up @@ -94,9 +94,15 @@ main (int argc, char *argv[])
transB = 'T';
}

memset(CC, 0, m * n * sizeof(bfloat16));
memset(DD, 0, m * n * sizeof(FLOAT));
memset(C, 0, m * n * sizeof(FLOAT));
for (j = 0; j < m; j++)
{
for (i = 0; i < n; i++)
{
C[j * n + i] = 100.0;
DD[j * n + i] = 100.0;
sbstobf16_(&one, &C[j * n + i], &one, &CC[j * n + i], &one);
}
}

SGEMM (&transA, &transB, &m, &n, &k, &alpha, A,
&m, B, &k, &beta, C, &m);
Expand Down Expand Up @@ -152,7 +158,8 @@ main (int argc, char *argv[])
}

if (ret != 0) {
fprintf (stderr, "FATAL ERROR BGEMM - Return code: %d\n", ret);
fprintf(stderr, "BGEMM FAILURES: %d\n", ret);
return 1;
}

return ret;
Expand Down
3 changes: 2 additions & 1 deletion test/compare_sgemm_sbgemm.c
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,8 @@ main (int argc, char *argv[])
}

if (ret != 0) {
fprintf (stderr, "FATAL ERROR SBGEMM - Return code: %d\n", ret);
fprintf(stderr, "SBGEMM FAILURES: %d\n", ret);
return 1;
}

return ret;
Expand Down
7 changes: 5 additions & 2 deletions test/compare_sgemv_bgemv.c
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,10 @@ int main(int argc, char *argv[])
} // alpha
} // beta

if (ret != 0)
fprintf(stderr, "FATAL ERROR BGEMV - Return code: %d\n", ret);
if (ret != 0) {
fprintf(stderr, "BGEMV FAILURES: %d\n", ret);
return 1;
}

return ret;
}
7 changes: 5 additions & 2 deletions test/compare_sgemv_sbgemv.c
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,10 @@ main (int argc, char *argv[])
} // alpha
} // beta

if (ret != 0)
fprintf (stderr, "FATAL ERROR SBGEMV - Return code: %d\n", ret);
if (ret != 0) {
fprintf(stderr, "SBGEMV FAILURES: %d\n", ret);
return 1;
}

return ret;
}
Loading