diff --git a/.gitignore b/.gitignore index dcbc73dc85..a6ee824681 100644 --- a/.gitignore +++ b/.gitignore @@ -80,6 +80,7 @@ test/SBLAT3_3M.SUMM test/ZBLAT2.SUMM test/ZBLAT3.SUMM test/ZBLAT3_3M.SUMM +test/SHBLAT2.SUMM test/SHBLAT3.SUMM test/SBBLAT2.SUMM test/SBBLAT3.SUMM @@ -98,6 +99,7 @@ test/sblat2 test/sblat3 test/sblat3_3m test/test_shgemm +test/test_shgemv test/test_sbgemm test/test_sbgemv test/test_bgemm diff --git a/cmake/kernel.cmake b/cmake/kernel.cmake index 6d752ac513..26f94bc0b5 100644 --- a/cmake/kernel.cmake +++ b/cmake/kernel.cmake @@ -175,6 +175,10 @@ if (BUILD_BFLOAT16) SetFallback(SBGEMVNKERNEL ../x86_64/sbgemv_n.c) SetFallback(SBGEMVTKERNEL ../x86_64/sbgemv_t.c) endif () +if (BUILD_HFLOAT16) + SetFallback(SHGEMVNKERNEL ../generic/gemv_n.c) + SetFallback(SHGEMVTKERNEL ../generic/gemv_t.c) +endif () endmacro () macro(SetDefaultL2) @@ -226,6 +230,8 @@ macro(SetDefaultL2) if (BUILD_BFLOAT16) SetFallback(BGEMVNKERNEL ../generic/gemv_n.c) SetFallback(BGEMVTKERNEL ../generic/gemv_t.c) + SetFallback(SHGEMVNKERNEL ../generic/gemv_n.c) + SetFallback(SHGEMVTKERNEL ../generic/gemv_t.c) SetFallback(SBGEMVNKERNEL ../x86_64/sbgemv_n.c) SetFallback(SBGEMVTKERNEL ../x86_64/sbgemv_t.c) SetFallback(SHGERKERNEL ../generic/ger.c) @@ -260,5 +266,16 @@ if (BUILD_BFLOAT16) SetFallback(SBGEMMONCOPYOBJ sbgemm_oncopy.o) SetFallback(SBGEMMOTCOPYOBJ sbgemm_otcopy.o) endif () - +if (BUILD_HFLOAT16) + SetFallback(SHGEMMKERNEL ../generic/gemmkernel_2x2.c) + SetFallback(SHGEMM_BETA ../generic/gemm_beta.c) + SetFallback(SHGEMMINCOPY ../generic/gemm_ncopy_2.c) + SetFallback(SHGEMMITCOPY ../generic/gemm_tcopy_2.c) + SetFallback(SHGEMMONCOPY ../generic/gemm_ncopy_2.c) + SetFallback(SHGEMMOTCOPY ../generic/gemm_tcopy_2.c) + SetFallback(SHGEMMINCOPYOBJ shgemm_incopy.o) + SetFallback(SHGEMMITCOPYOBJ shgemm_itcopy.o) + SetFallback(SHGEMMONCOPYOBJ shgemm_oncopy.o) + SetFallback(SHGEMMOTCOPYOBJ shgemm_otcopy.o) +endif () endmacro () diff --git a/cmake/utils.cmake b/cmake/utils.cmake index 
35843a3265..e717233c1f 100644 --- a/cmake/utils.cmake +++ b/cmake/utils.cmake @@ -375,9 +375,12 @@ function(GenerateNamedObjects sources_in) if (NOT no_float_type) string(SUBSTRING ${float_type} 0 1 float_char) string(TOLOWER ${float_char} float_char) - if (${float_type} STREQUAL "BFLOAT16" AND NOT "${defines_in}" MATCHES "BGEM") - set (float_char "sb") - endif () + if (${float_type} STREQUAL "BFLOAT16" AND NOT "${defines_in}" MATCHES "BGEM") + set (float_char "sb") + endif () + if (${float_type} STREQUAL "HFLOAT16" AND NOT "${defines_in}" MATCHES "HGEM") + set (float_char "sh") + endif () endif () if (NOT name_in) diff --git a/common_interface.h b/common_interface.h index 945b6c8a1a..380ce8d081 100644 --- a/common_interface.h +++ b/common_interface.h @@ -261,6 +261,8 @@ void BLASFUNC(bgemv)(char *, blasint *, blasint *, bfloat16 *, bfloat16 *, blas bfloat16 *, blasint *, bfloat16 *, bfloat16 *, blasint *); void BLASFUNC(sbgemv)(char *, blasint *, blasint *, float *, bfloat16 *, blasint *, bfloat16 *, blasint *, float *, float *, blasint *); +void BLASFUNC(shgemv)(char *, blasint *, blasint *, float *, hfloat16 *, blasint *, + hfloat16 *, blasint *, float *, float *, blasint *); void BLASFUNC(sgemv)(char *, blasint *, blasint *, float *, float *, blasint *, float *, blasint *, float *, float *, blasint *); void BLASFUNC(dgemv)(char *, blasint *, blasint *, double *, double *, blasint *, diff --git a/common_level2.h b/common_level2.h index eea5e43f3c..492787cf70 100644 --- a/common_level2.h +++ b/common_level2.h @@ -54,6 +54,10 @@ int sbgemv_n(BLASLONG, BLASLONG, float, bfloat16 *, BLASLONG, bfloat16 *, BLASLO int sbgemv_t(BLASLONG, BLASLONG, float, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, float, float *, BLASLONG); int sbgemv_thread_n(BLASLONG, BLASLONG, float, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, float, float *, BLASLONG, int); int sbgemv_thread_t(BLASLONG, BLASLONG, float, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, float, float *, BLASLONG, int); +int 
shgemv_n(BLASLONG, BLASLONG, float, hfloat16 *, BLASLONG, hfloat16 *, BLASLONG, float, float *, BLASLONG); +int shgemv_t(BLASLONG, BLASLONG, float, hfloat16 *, BLASLONG, hfloat16 *, BLASLONG, float, float *, BLASLONG); +int shgemv_thread_n(BLASLONG, BLASLONG, float, hfloat16 *, BLASLONG, hfloat16 *, BLASLONG, float, float *, BLASLONG, int); +int shgemv_thread_t(BLASLONG, BLASLONG, float, hfloat16 *, BLASLONG, hfloat16 *, BLASLONG, float, float *, BLASLONG, int); int sger_k (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG, float *); int dger_k (BLASLONG, BLASLONG, BLASLONG, double, double *, BLASLONG, double *, BLASLONG, double *, BLASLONG, double *); int qger_k (BLASLONG, BLASLONG, BLASLONG, xdouble, xdouble *, BLASLONG, xdouble *, BLASLONG, xdouble *, BLASLONG, xdouble *); diff --git a/common_macro.h b/common_macro.h index f9c22089b3..745643fa89 100644 --- a/common_macro.h +++ b/common_macro.h @@ -703,6 +703,9 @@ #define GEMM_THREAD_RC SHGEMM_THREAD_NT #define GEMM_THREAD_RR SHGEMM_THREAD_NN +#define SCAL_K SSCAL_K +#define GEMV_N SHGEMV_N_K +#define GEMV_T SHGEMV_T_K #elif defined(BFLOAT16) && defined(BGEMM) #define SCAL_K BSCAL_K diff --git a/common_param.h b/common_param.h index 54bb896fc5..efd5912147 100644 --- a/common_param.h +++ b/common_param.h @@ -60,7 +60,8 @@ int (*shgemm_itcopy )(BLASLONG, BLASLONG, hfloat16 *, BLASLONG, hfloat16 *); int (*shgemm_oncopy )(BLASLONG, BLASLONG, hfloat16 *, BLASLONG, hfloat16 *); int (*shgemm_otcopy )(BLASLONG, BLASLONG, hfloat16 *, BLASLONG, hfloat16 *); - +int (*shgemv_n) (BLASLONG, BLASLONG, float, hfloat16 *, BLASLONG, hfloat16 *, BLASLONG, float, float *, BLASLONG); +int (*shgemv_t) (BLASLONG, BLASLONG, float, hfloat16 *, BLASLONG, hfloat16 *, BLASLONG, float, float *, BLASLONG); #endif diff --git a/common_sh.h b/common_sh.h index 69734d1dc2..99dbb65180 100644 --- a/common_sh.h +++ b/common_sh.h @@ -1,3 +1,31 @@ 
+/*************************************************************************** + * Copyright (c) 2025, The OpenBLAS Project + * All rights reserved. + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in + * the documentation and/or other materials provided with the + * distribution. + * 3. Neither the name of the OpenBLAS project nor the names of + * its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. 
+ * *****************************************************************************/ + #ifndef COMMON_SH_H #define COMMON_SH_H @@ -17,6 +45,9 @@ #define SHGEMM_BETA shgemm_beta #define SHGEMM_KERNEL shgemm_kernel +#define SHGEMV_N_K shgemv_n +#define SHGEMV_T_K shgemv_t + #else // #DYNAMIC_ARCH @@ -32,6 +63,10 @@ #define SHGEMM_BETA gotoblas -> shgemm_beta #define SHGEMM_KERNEL gotoblas -> shgemm_kernel + +#define SHGEMV_N_K gotoblas->shgemv_n +#define SHGEMV_T_K gotoblas->shgemv_t + #endif // #DYNAMIC_ARCH #define SHGEMM_NN shgemm_nn diff --git a/driver/level2/Makefile b/driver/level2/Makefile index 3f3731d3f9..d50e70bcd1 100644 --- a/driver/level2/Makefile +++ b/driver/level2/Makefile @@ -450,6 +450,12 @@ XBLASOBJS += \ xtbmv_thread_CUU.$(SUFFIX) xtbmv_thread_CUN.$(SUFFIX) \ xtbmv_thread_CLU.$(SUFFIX) xtbmv_thread_CLN.$(SUFFIX) + +ifeq ($(BUILD_HFLOAT16),1) +SHBLASOBJS += \ + shgemv_thread_n$(TSUFFIX).$(SUFFIX) \ + shgemv_thread_t$(TSUFFIX).$(SUFFIX) +endif ifeq ($(BUILD_BFLOAT16),1) BBLASOBJS += \ bgemv_thread_n$(TSUFFIX).$(SUFFIX) \ @@ -3737,6 +3743,13 @@ xtrsv_CUU.$(SUFFIX) xtrsv_CUU.$(PSUFFIX) : ztrsv_L.c ../../param.h xtrsv_CUN.$(SUFFIX) xtrsv_CUN.$(PSUFFIX) : ztrsv_L.c ../../param.h $(CC) -c $(CFLAGS) -DXDOUBLE -DCOMPLEX -DTRANSA=4 -UUNIT $< -o $(@F) +ifeq ($(BUILD_HFLOAT16),1) +shgemv_thread_n.$(SUFFIX) shgemv_thread_n.$(PSUFFIX) : sbgemv_thread.c ../../common.h + $(CC) -c $(CFLAGS) -UCOMPLEX -UDOUBLE -UTRANSA -UCONJ -UXCONJ $< -o $(@F) +shgemv_thread_t.$(SUFFIX) shgemv_thread_t.$(PSUFFIX) : sbgemv_thread.c ../../common.h + $(CC) -c $(CFLAGS) -UCOMPLEX -UDOUBLE -DTRANSA -UCONJ -UXCONJ $< -o $(@F) +endif + ifeq ($(BUILD_BFLOAT16),1) bgemv_thread_n.$(SUFFIX) bgemv_thread_n.$(PSUFFIX) : sbgemv_thread.c ../../common.h $(CC) -c $(CFLAGS) -DBGEMM -UCOMPLEX -UDOUBLE -UTRANSA -UCONJ -UXCONJ $< -o $(@F) diff --git a/exports/gensymbol b/exports/gensymbol index 40e13e623f..01c930ea97 100755 --- a/exports/gensymbol +++ b/exports/gensymbol @@ -80,7 +80,7 @@ blasobjsz=" 
blasobjs="lsame xerbla" bfblasobjs="bgemm bgemv sbgemm sbgemmt sbgemmtr sbgemv sbdot sbstobf16 sbdtobf16 sbf16tos dbf16tod" -hfblasobjs="shgemm" +hfblasobjs="shgemm shgemv" cblasobjsc=" cblas_caxpy cblas_ccopy cblas_cdotc cblas_cdotu cblas_cgbmv cblas_cgemm cblas_cgemv cblas_cgerc cblas_cgeru cblas_chbmv cblas_chemm cblas_chemv cblas_cher2 cblas_cher2k diff --git a/exports/gensymbol.pl b/exports/gensymbol.pl index 3447a4e515..bdfa69b27b 100644 --- a/exports/gensymbol.pl +++ b/exports/gensymbol.pl @@ -80,7 +80,7 @@ @blasobjs = (lsame, xerbla); @bfblasobjs = (bgemm, bgemv, sbgemm, sbgemmt, sbgemmtr, sbgemv, sbdot, sbstobf16, sbdtobf16, sbf16tos, dbf16tod); -@hfblasobjs = (shgemm); +@hfblasobjs = (shgemm, shgemv); @cblasobjsc = ( cblas_caxpy, cblas_ccopy, cblas_cdotc, cblas_cdotu, cblas_cgbmv, cblas_cgemm, cblas_cgemv, cblas_cgerc, cblas_cgeru, cblas_chbmv, cblas_chemm, cblas_chemv, cblas_cher2, cblas_cher2k, diff --git a/interface/CMakeLists.txt b/interface/CMakeLists.txt index e8e0a88969..ee7d40d382 100644 --- a/interface/CMakeLists.txt +++ b/interface/CMakeLists.txt @@ -166,6 +166,7 @@ if (BUILD_BFLOAT16) endif () if (BUILD_HFLOAT16) GenerateNamedObjects("gemm.c" "" "shgemm" ${CBLAS_FLAG} "" "" true "HFLOAT16") + GenerateNamedObjects("sbgemv.c" "" "shgemv" ${CBLAS_FLAG} "" "" true "HFLOAT16") endif () # complex-specific sources diff --git a/interface/Makefile b/interface/Makefile index abc6d053de..83a894b125 100644 --- a/interface/Makefile +++ b/interface/Makefile @@ -87,6 +87,7 @@ endif ifeq ($(BUILD_HFLOAT16),1) SHBLAS3OBJS = shgemm.$(SUFFIX) +SHBLAS2OBJS = shgemv.$(SUFFIX) endif DBLAS1OBJS = \ @@ -338,6 +339,7 @@ endif ifeq ($(BUILD_HFLOAT16),1) CSHBLAS3OBJS = cblas_shgemm.$(SUFFIX) +CSHBLAS2OBJS = cblas_shgemv.$(SUFFIX) endif CDBLAS1OBJS = \ @@ -441,6 +443,7 @@ SBBLAS1OBJS += $(CSBBLAS1OBJS) SBBLAS2OBJS += $(CSBBLAS2OBJS) SBBLAS3OBJS += $(CSBBLAS3OBJS) SHBLAS3OBJS += $(CSHBLAS3OBJS) +SHBLAS2OBJS += $(CSHBLAS2OBJS) DBLAS1OBJS += $(CDBLAS1OBJS) DBLAS2OBJS += 
$(CDBLAS2OBJS) DBLAS3OBJS += $(CDBLAS3OBJS) @@ -459,7 +462,7 @@ endif BBLASOBJS = $(BBLAS3OBJS) $(BBLAS2OBJS) $(BBLAS1OBJS) SBLASOBJS = $(SBLAS1OBJS) $(SBLAS2OBJS) $(SBLAS3OBJS) SBBLASOBJS = $(SBBLAS1OBJS) $(SBBLAS2OBJS) $(SBBLAS3OBJS) -SHBLASOBJS = $(SHBLAS3OBJS) +SHBLASOBJS = $(SHBLAS3OBJS) $(SHBLAS2OBJS) DBLASOBJS = $(DBLAS1OBJS) $(DBLAS2OBJS) $(DBLAS3OBJS) QBLASOBJS = $(QBLAS1OBJS) $(QBLAS2OBJS) $(QBLAS3OBJS) CBLASOBJS = $(CBLAS1OBJS) $(CBLAS2OBJS) $(CBLAS3OBJS) @@ -602,7 +605,7 @@ clean :: level1 : $(SBEXTOBJS) $(SBBLAS1OBJS) $(SBLAS1OBJS) $(DBLAS1OBJS) $(QBLAS1OBJS) $(CBLAS1OBJS) $(ZBLAS1OBJS) $(XBLAS1OBJS) $(AR) $(ARFLAGS) -ru $(TOPDIR)/$(LIBNAME) $^ -level2 : $(SBBLAS2OBJS) $(BBLAS2OBJS) $(SBLAS2OBJS) $(DBLAS2OBJS) $(QBLAS2OBJS) $(CBLAS2OBJS) $(ZBLAS2OBJS) $(XBLAS2OBJS) +level2 : $(SBBLAS2OBJS) $(BBLAS2OBJS) $(SBLAS2OBJS) $(DBLAS2OBJS) $(QBLAS2OBJS) $(CBLAS2OBJS) $(ZBLAS2OBJS) $(XBLAS2OBJS) $(SHBLAS2OBJS) $(AR) $(ARFLAGS) -ru $(TOPDIR)/$(LIBNAME) $^ level3 : $(SBBLAS3OBJS) $(BBLAS3OBJ) $(SBLAS3OBJS) $(DBLAS3OBJS) $(QBLAS3OBJS) $(CBLAS3OBJS) $(ZBLAS3OBJS) $(XBLAS3OBJS) $(SHBLAS3OBJS) @@ -1002,6 +1005,11 @@ sbgemv.$(SUFFIX) sbgemv.$(PSUFFIX) : sbgemv.c $(CC) $(CFLAGS) -c $< -o $(@F) endif +ifeq ($(BUILD_HFLOAT16),1) +shgemv.$(SUFFIX) shgemv.$(PSUFFIX) : sbgemv.c + $(CC) $(CFLAGS) -c $< -o $(@F) +endif + ifndef USE_NETLIB_GEMV sgemv.$(SUFFIX) sgemv.$(PSUFFIX): gemv.c $(CC) -c $(CFLAGS) -o $(@F) $< @@ -1832,6 +1840,11 @@ cblas_sbgemv.$(SUFFIX) cblas_sbgemv.$(PSUFFIX) : sbgemv.c $(CC) -DCBLAS -c $(CFLAGS) $< -o $(@F) endif +ifeq ($(BUILD_HFLOAT16),1) +cblas_shgemv.$(SUFFIX) cblas_shgemv.$(PSUFFIX) : sbgemv.c + $(CC) -DCBLAS -c $(CFLAGS) $< -o $(@F) +endif + cblas_sgemv.$(SUFFIX) cblas_sgemv.$(PSUFFIX): gemv.c $(CC) -DCBLAS -c $(CFLAGS) -o $(@F) $< diff --git a/interface/gemm.c b/interface/gemm.c index c5182c266a..28f962ad87 100644 --- a/interface/gemm.c +++ b/interface/gemm.c @@ -587,7 +587,10 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, 
enum CBLAS_TRANS args.m, args.n, args.k, args.lda, args.ldb, args.ldc); #endif -#if defined(GEMM_GEMV_FORWARD) && !defined(GEMM3M) && !defined(COMPLEX) && !defined(HFLOAT16) && (!defined(BFLOAT16) || (!defined(BGEMM) && defined(SBGEMM_GEMV_FORWARD)) || (defined(BGEMM) && defined(BGEMM_GEMV_FORWARD))) +/* GEMM-to-GEMV forwarding conditions are spelled out inline: a macro that expands to defined() */ +/* inside #if is undefined behaviour (C11 6.10.1p4) and is flagged by -Wexpansion-to-defined. */ + +#if defined(GEMM_GEMV_FORWARD) && !defined(GEMM3M) && !defined(COMPLEX) && (!defined(HFLOAT16) || (!defined(HGEMM) && defined(SHGEMM_GEMV_FORWARD)) || (defined(HGEMM) && defined(HGEMM_GEMV_FORWARD))) && (!defined(BFLOAT16) || (!defined(BGEMM) && defined(SBGEMM_GEMV_FORWARD)) || (defined(BGEMM) && defined(BGEMM_GEMV_FORWARD))) #if defined(ARCH_ARM64) // The gemv kernels in arm64/{gemv_n.S,gemv_n_sve.c,gemv_t.S,gemv_t_sve.c} // perform poorly in certain circumstances. We use the following boolean diff --git a/interface/sbgemv.c b/interface/sbgemv.c index cee3e80fcf..12db2dfb1c 100644 --- a/interface/sbgemv.c +++ b/interface/sbgemv.c @@ -48,6 +48,10 @@ #define GEMV_THREAD_N bgemv_thread_n #define GEMV_THREAD_T bgemv_thread_t #define ERROR_NAME "BGEMV " +#elif defined(HFLOAT16) +#define GEMV_THREAD_N shgemv_thread_n +#define GEMV_THREAD_T shgemv_thread_t +#define ERROR_NAME "SHGEMV " #else #define GEMV_THREAD_N sbgemv_thread_n #define GEMV_THREAD_T sbgemv_thread_t diff --git a/kernel/CMakeLists.txt b/kernel/CMakeLists.txt index d73dc27e29..ffae1973bb 100644 --- a/kernel/CMakeLists.txt +++ b/kernel/CMakeLists.txt @@ -228,6 +228,10 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS) GenerateNamedObjects("${KERNELDIR}/${SBGEMVNKERNEL}" "" "gemv_n" false "" "" false "BFLOAT16") GenerateNamedObjects("${KERNELDIR}/${SBGEMVTKERNEL}" "" "gemv_t" false "" "" false "BFLOAT16") endif () + if (BUILD_HFLOAT16) + 
GenerateNamedObjects("${KERNELDIR}/${SHGEMVTKERNEL}" "" "gemv_t" false "" "" false "HFLOAT16") + endif () # Makefile.L3 set(USE_TRMM false) string(TOUPPER ${TARGET_CORE} UC_TARGET_CORE) diff --git a/kernel/Makefile.L2 b/kernel/Makefile.L2 index a9fcf92250..aea0c9cbb4 100644 --- a/kernel/Makefile.L2 +++ b/kernel/Makefile.L2 @@ -101,6 +101,16 @@ SBGEMVTKERNEL = ../x86_64/sbgemv_t.c endif endif +ifeq ($(BUILD_HFLOAT16),1) +ifndef SHGEMVNKERNEL +SHGEMVNKERNEL = ../generic/gemv_n.c +endif + +ifndef SHGEMVTKERNEL +SHGEMVTKERNEL = ../generic/gemv_t.c +endif +endif + ### GER ### ifndef SGERKERNEL @@ -299,6 +309,12 @@ SBBLASOBJS += \ sbgemv_t$(TSUFFIX).$(SUFFIX) endif +ifeq ($(BUILD_HFLOAT16),1) +SHBLASOBJS += \ + shgemv_n$(TSUFFIX).$(SUFFIX) \ + shgemv_t$(TSUFFIX).$(SUFFIX) +endif + ifneq "$(or $(BUILD_SINGLE), $(BUILD_DOUBLE), $(BUILD_COMPLEX))" "" $(KDIR)sgemv_n$(TSUFFIX).$(SUFFIX) $(KDIR)sgemv_n$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SGEMVNKERNEL) $(TOPDIR)/common.h $(GEMVDEP) $(CC) -c $(CFLAGS) -UDOUBLE -UCOMPLEX -UTRANS $< -o $@ @@ -558,3 +574,10 @@ $(KDIR)bgemv_t$(TSUFFIX).$(SUFFIX) $(KDIR)bgemv_t$(TPSUFFIX).$(PSUFFIX) : $(KERN $(CC) -c $(CFLAGS) -DBGEMM -UCOMPLEX $< -o $@ endif +ifeq ($(BUILD_HFLOAT16),1) +$(KDIR)shgemv_n$(TSUFFIX).$(SUFFIX) $(KDIR)shgemv_n$(TPSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SHGEMVNKERNEL) + $(CC) -c $(CFLAGS) -UCOMPLEX $< -o $@ +$(KDIR)shgemv_t$(TSUFFIX).$(SUFFIX) $(KDIR)shgemv_t$(TPSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SHGEMVTKERNEL) + $(CC) -c $(CFLAGS) -UCOMPLEX $< -o $@ +endif + diff --git a/kernel/generic/bf16_macros.h b/kernel/generic/conversion_macros.h similarity index 83% rename from kernel/generic/bf16_macros.h rename to kernel/generic/conversion_macros.h index f1b02cea4a..69f8520128 100644 --- a/kernel/generic/bf16_macros.h +++ b/kernel/generic/conversion_macros.h @@ -27,6 +27,7 @@ * *****************************************************************************/ #if defined(BFLOAT16) && defined(BFLOAT16CONVERSION) + static float 
bfloat16tof32 (bfloat16 value) { @@ -48,17 +49,34 @@ static bfloat16 f32tobfloat16(float value) { #ifdef BGEMM #define ALPHA bfloat16tof32(alpha) #define BETA bfloat16tof32(beta) -#define BF16TOF32(x) (bfloat16tof32(x)) -#define F32TOBF16(x) (f32tobfloat16(x)) +#define TO_F32(x) (bfloat16tof32(x)) +#define TO_OUTPUT(x) (f32tobfloat16(x)) +#else +#define ALPHA alpha +#define BETA beta +#define TO_F32(x) (bfloat16tof32(x)) +#define TO_OUTPUT(x) x +#endif + +#elif defined(HFLOAT16) + +#ifdef HGEMM +#define ALPHA (float)(alpha) +#define BETA (float)(beta) +#define TO_F32(x) ((float)(x)) +#define TO_OUTPUT(x) ((_Float16)(x)) #else #define ALPHA alpha #define BETA beta -#define BF16TOF32(x) (bfloat16tof32(x)) -#define F32TOBF16(x) x +#define TO_F32(x) ((float)(x)) +#define TO_OUTPUT(x) x #endif + #else + #define ALPHA alpha #define BETA beta -#define BF16TOF32(x) x -#define F32TOBF16(x) x +#define TO_F32(x) x +#define TO_OUTPUT(x) x + #endif diff --git a/kernel/generic/gemmkernel_2x2.c b/kernel/generic/gemmkernel_2x2.c index c24370c890..07da2cbc87 100644 --- a/kernel/generic/gemmkernel_2x2.c +++ b/kernel/generic/gemmkernel_2x2.c @@ -27,7 +27,8 @@ * *****************************************************************************/ #include "common.h" -#include "bf16_macros.h" + +#include "conversion_macros.h" int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,IFLOAT* ba,IFLOAT* bb,FLOAT* C,BLASLONG ldc #ifdef TRMMKERNEL @@ -60,36 +61,36 @@ int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,IFLOAT* ba,IFLOAT* bb, { load0 = ptrba[2*0+0]; load1 = ptrbb[2*0+0]; - res0 = res0+BF16TOF32(load0)*BF16TOF32(load1); + res0 = res0+TO_F32(load0)*TO_F32(load1); load2 = ptrba[2*0+1]; - res1 = res1+BF16TOF32(load2)*BF16TOF32(load1); + res1 = res1+TO_F32(load2)*TO_F32(load1); load3 = ptrbb[2*0+1]; - res2 = res2+BF16TOF32(load0)*BF16TOF32(load3); - res3 = res3+BF16TOF32(load2)*BF16TOF32(load3); + res2 = res2+TO_F32(load0)*TO_F32(load3); + res3 = res3+TO_F32(load2)*TO_F32(load3); 
load4 = ptrba[2*1+0]; load5 = ptrbb[2*1+0]; - res0 = res0+BF16TOF32(load4)*BF16TOF32(load5); + res0 = res0+TO_F32(load4)*TO_F32(load5); load6 = ptrba[2*1+1]; - res1 = res1+BF16TOF32(load6)*BF16TOF32(load5); + res1 = res1+TO_F32(load6)*TO_F32(load5); load7 = ptrbb[2*1+1]; - res2 = res2+BF16TOF32(load4)*BF16TOF32(load7); - res3 = res3+BF16TOF32(load6)*BF16TOF32(load7); + res2 = res2+TO_F32(load4)*TO_F32(load7); + res3 = res3+TO_F32(load6)*TO_F32(load7); load0 = ptrba[2*2+0]; load1 = ptrbb[2*2+0]; - res0 = res0+BF16TOF32(load0)*BF16TOF32(load1); + res0 = res0+TO_F32(load0)*TO_F32(load1); load2 = ptrba[2*2+1]; - res1 = res1+BF16TOF32(load2)*BF16TOF32(load1); + res1 = res1+TO_F32(load2)*TO_F32(load1); load3 = ptrbb[2*2+1]; - res2 = res2+BF16TOF32(load0)*BF16TOF32(load3); - res3 = res3+BF16TOF32(load2)*BF16TOF32(load3); + res2 = res2+TO_F32(load0)*TO_F32(load3); + res3 = res3+TO_F32(load2)*TO_F32(load3); load4 = ptrba[2*3+0]; load5 = ptrbb[2*3+0]; - res0 = res0+BF16TOF32(load4)*BF16TOF32(load5); + res0 = res0+TO_F32(load4)*TO_F32(load5); load6 = ptrba[2*3+1]; - res1 = res1+BF16TOF32(load6)*BF16TOF32(load5); + res1 = res1+TO_F32(load6)*TO_F32(load5); load7 = ptrbb[2*3+1]; - res2 = res2+BF16TOF32(load4)*BF16TOF32(load7); - res3 = res3+BF16TOF32(load6)*BF16TOF32(load7); + res2 = res2+TO_F32(load4)*TO_F32(load7); + res3 = res3+TO_F32(load6)*TO_F32(load7); ptrba = ptrba+8; ptrbb = ptrbb+8; } @@ -97,23 +98,23 @@ int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,IFLOAT* ba,IFLOAT* bb, { load0 = ptrba[2*0+0]; load1 = ptrbb[2*0+0]; - res0 = res0+BF16TOF32(load0)*BF16TOF32(load1); + res0 = res0+TO_F32(load0)*TO_F32(load1); load2 = ptrba[2*0+1]; - res1 = res1+BF16TOF32(load2)*BF16TOF32(load1); + res1 = res1+TO_F32(load2)*TO_F32(load1); load3 = ptrbb[2*0+1]; - res2 = res2+BF16TOF32(load0)*BF16TOF32(load3); - res3 = res3+BF16TOF32(load2)*BF16TOF32(load3); + res2 = res2+TO_F32(load0)*TO_F32(load3); + res3 = res3+TO_F32(load2)*TO_F32(load3); ptrba = ptrba+2; ptrbb = ptrbb+2; } 
res0 = res0*ALPHA; - C0[0] = F32TOBF16(BF16TOF32(C0[0])+res0); + C0[0] = TO_OUTPUT(TO_F32(C0[0])+res0); res1 = res1*ALPHA; - C0[1] = F32TOBF16(BF16TOF32(C0[1])+res1); + C0[1] = TO_OUTPUT(TO_F32(C0[1])+res1); res2 = res2*ALPHA; - C1[0] = F32TOBF16(BF16TOF32(C1[0])+res2); + C1[0] = TO_OUTPUT(TO_F32(C1[0])+res2); res3 = res3*ALPHA; - C1[1] = F32TOBF16(BF16TOF32(C1[1])+res3); + C1[1] = TO_OUTPUT(TO_F32(C1[1])+res3); C0 = C0+2; C1 = C1+2; } @@ -126,16 +127,16 @@ int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,IFLOAT* ba,IFLOAT* bb, { load0 = ptrba[0+0]; load1 = ptrbb[2*0+0]; - res0 = res0+BF16TOF32(load0)*BF16TOF32(load1); + res0 = res0+TO_F32(load0)*TO_F32(load1); load2 = ptrbb[2*0+1]; - res1 = res1+BF16TOF32(load0)*BF16TOF32(load2); + res1 = res1+TO_F32(load0)*TO_F32(load2); ptrba = ptrba+1; ptrbb = ptrbb+2; } res0 = res0*ALPHA; - C0[0] = F32TOBF16(BF16TOF32(C0[0])+res0); + C0[0] = TO_OUTPUT(TO_F32(C0[0])+res0); res1 = res1*ALPHA; - C1[0] = F32TOBF16(BF16TOF32(C1[0])+res1); + C1[0] = TO_OUTPUT(TO_F32(C1[0])+res1); C0 = C0+1; C1 = C1+1; } @@ -157,16 +158,16 @@ int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,IFLOAT* ba,IFLOAT* bb, { load0 = ptrba[2*0+0]; load1 = ptrbb[0+0]; - res0 = res0+BF16TOF32(load0)*BF16TOF32(load1); + res0 = res0+TO_F32(load0)*TO_F32(load1); load2 = ptrba[2*0+1]; - res1 = res1+BF16TOF32(load2)*BF16TOF32(load1); + res1 = res1+TO_F32(load2)*TO_F32(load1); ptrba = ptrba+2; ptrbb = ptrbb+1; } res0 = res0*ALPHA; - C0[0] = F32TOBF16(BF16TOF32(C0[0])+res0); + C0[0] = TO_OUTPUT(TO_F32(C0[0])+res0); res1 = res1*ALPHA; - C0[1] = F32TOBF16(BF16TOF32(C0[1])+res1); + C0[1] = TO_OUTPUT(TO_F32(C0[1])+res1); C0 = C0+2; } for (i=0; i<(bm&1); i+=1) @@ -177,12 +178,12 @@ int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,IFLOAT* ba,IFLOAT* bb, { load0 = ptrba[0+0]; load1 = ptrbb[0+0]; - res0 = res0+BF16TOF32(load0)*BF16TOF32(load1); + res0 = res0+TO_F32(load0)*TO_F32(load1); ptrba = ptrba+1; ptrbb = ptrbb+1; } res0 = res0*ALPHA; - C0[0] = 
F32TOBF16(BF16TOF32(C0[0])+res0); + C0[0] = TO_OUTPUT(TO_F32(C0[0])+res0); C0 = C0+1; } k = (bk<<0); diff --git a/kernel/generic/gemv_n.c b/kernel/generic/gemv_n.c index 1c72b07af5..b7e8950c82 100644 --- a/kernel/generic/gemv_n.c +++ b/kernel/generic/gemv_n.c @@ -26,15 +26,14 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. *****************************************************************************/ #include "common.h" -#include "bf16_macros.h" + +#include "conversion_macros.h" int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT *x, BLASLONG inc_x, FLOAT beta, FLOAT *y, BLASLONG inc_y) { - BLASLONG i; BLASLONG ix, iy; - BLASLONG j; - FLOAT *a_ptr; -#ifdef BGEMM + IFLOAT *a_ptr; +#if defined(BGEMM) || defined(HGEMM) float temp; #else FLOAT temp; @@ -49,18 +48,18 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT * a_ptr = a; for (BLASLONG j = 0; j < n; j++) { - temp += BF16TOF32(a_ptr[i]) * BF16TOF32(x[ix]); + temp += TO_F32(a_ptr[i]) * TO_F32(x[ix]); ix += inc_x; a_ptr += lda; } if (BETA == ZERO) { - y[iy] = F32TOBF16(ALPHA * temp); + y[iy] = TO_OUTPUT(ALPHA * temp); } else { - y[iy] = F32TOBF16(ALPHA * temp + BETA * BF16TOF32(y[iy])); + y[iy] = TO_OUTPUT(ALPHA * temp + BETA * TO_F32(y[iy])); } iy += inc_y; diff --git a/kernel/generic/gemv_t.c b/kernel/generic/gemv_t.c index ecf8ebbad5..5124fafde1 100644 --- a/kernel/generic/gemv_t.c +++ b/kernel/generic/gemv_t.c @@ -26,15 +26,16 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
*****************************************************************************/ #include "common.h" -#include "bf16_macros.h" + +#include "conversion_macros.h" int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT *x, BLASLONG inc_x, FLOAT beta, FLOAT *y, BLASLONG inc_y) { BLASLONG i; BLASLONG ix, iy; BLASLONG j; - FLOAT *a_ptr; -#ifdef BGEMM + IFLOAT *a_ptr; +#if defined(BGEMM) || defined(HGEMM) float temp; #else FLOAT temp; @@ -49,16 +50,16 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT * ix = 0; for (i = 0; i < m; i++) { - temp += BF16TOF32(a_ptr[i]) * BF16TOF32(x[ix]); + temp += TO_F32(a_ptr[i]) * TO_F32(x[ix]); ix += inc_x; } if (BETA == ZERO) { - y[iy] = F32TOBF16(ALPHA * temp); + y[iy] = TO_OUTPUT(ALPHA * temp); } else { - y[iy] = F32TOBF16(ALPHA * temp + BETA * BF16TOF32(y[iy])); + y[iy] = TO_OUTPUT(ALPHA * temp + BETA * TO_F32(y[iy])); } iy += inc_y; a_ptr += lda; diff --git a/kernel/setparam-ref.c b/kernel/setparam-ref.c index df455cd5d8..bf114981c4 100644 --- a/kernel/setparam-ref.c +++ b/kernel/setparam-ref.c @@ -56,6 +56,24 @@ gotoblas_t TABLE_NAME = { GEMM_DEFAULT_OFFSET_A, GEMM_DEFAULT_OFFSET_B, GEMM_DEFAULT_ALIGN, +#ifdef BUILD_HFLOAT16 + 0, 0, 0, + SHGEMM_DEFAULT_UNROLL_M, SHGEMM_DEFAULT_UNROLL_N, +#ifdef SHGEMM_DEFAULT_UNROLL_MN + SHGEMM_DEFAULT_UNROLL_MN, +#else + MAX(SHGEMM_DEFAULT_UNROLL_M, SHGEMM_DEFAULT_UNROLL_N), +#endif + shgemm_kernelTS, shgemm_betaTS, +#if SHGEMM_DEFAULT_UNROLL_M != SHGEMM_DEFAULT_UNROLL_N + shgemm_incopyTS, shgemm_itcopyTS, +#else + shgemm_oncopyTS, shgemm_otcopyTS, +#endif + shgemm_oncopyTS, shgemm_otcopyTS, + shgemv_nTS, shgemv_tTS, +#endif + #ifdef BUILD_BFLOAT16 0, 0, 0, BGEMM_DEFAULT_UNROLL_M, BGEMM_DEFAULT_UNROLL_N, @@ -142,23 +160,6 @@ gotoblas_t TABLE_NAME = { #endif #endif -#ifdef BUILD_HFLOAT16 - 0, 0, 0, - SHGEMM_DEFAULT_UNROLL_M, SHGEMM_DEFAULT_UNROLL_N, -#ifdef SHGEMM_DEFAULT_UNROLL_MN - SHGEMM_DEFAULT_UNROLL_MN, -#else - MAX(SHGEMM_DEFAULT_UNROLL_M, 
SHGEMM_DEFAULT_UNROLL_N), -#endif - shgemm_kernelTS, shgemm_betaTS, -#if SHGEMM_DEFAULT_UNROLL_M != SHGEMM_DEFAULT_UNROLL_N - shgemm_incopyTS, shgemm_itcopyTS, -#else - shgemm_oncopyTS, shgemm_otcopyTS, -#endif - shgemm_oncopyTS, shgemm_otcopyTS, -#endif - #if ( BUILD_SINGLE==1) || (BUILD_DOUBLE==1) || (BUILD_COMPLEX==1) || (BUILD_COMPLEX16==1) 0, 0, 0, SGEMM_DEFAULT_UNROLL_M, SGEMM_DEFAULT_UNROLL_N, diff --git a/test/Makefile b/test/Makefile index 58b6710c6b..f29bd35471 100644 --- a/test/Makefile +++ b/test/Makefile @@ -119,6 +119,9 @@ endif endif endif +ifeq ($(BUILD_HFLOAT16), 1) +SH2 = test_shgemv +endif ifeq ($(BUILD_BFLOAT16), 1) BB2 = test_bgemv B2 = test_sbgemv @@ -136,7 +139,7 @@ ifeq ($(BUILD_COMPLEX16),1) Z2=zblat2 endif -level2: $(BB2) $(B2) $(S2) $(D2) $(C2) $(Z2) +level2: $(SH2) $(BB2) $(B2) $(S2) $(D2) $(C2) $(Z2) ifneq ($(CROSS), 1) @@ -147,6 +150,10 @@ ifeq ($(BUILD_BFLOAT16),1) OPENBLAS_NUM_THREADS=1 OMP_NUM_THREADS=1 ./test_sbgemv > SBBLAT2.SUMM @$(GREP) -q FATAL SBBLAT2.SUMM && cat SBBLAT2.SUMM || exit 0 endif +ifeq ($(BUILD_HFLOAT16),1) + OPENBLAS_NUM_THREADS=1 OMP_NUM_THREADS=1 ./test_shgemv > SHBLAT2.SUMM + @$(GREP) -q FATAL SHBLAT2.SUMM && cat SHBLAT2.SUMM || exit 0 +endif ifeq ($(BUILD_SINGLE),1) OPENBLAS_NUM_THREADS=1 OMP_NUM_THREADS=1 ./sblat2 < ./sblat2.dat @$(GREP) -q FATAL SBLAT2.SUMM && cat SBLAT2.SUMM || exit 0 @@ -172,6 +179,10 @@ ifeq ($(BUILD_BFLOAT16),1) OMP_NUM_THREADS=2 ./test_sbgemv > SBBLAT2.SUMM @$(GREP) -q FATAL SBBLAT2.SUMM && cat SBBLAT2.SUMM || exit 0 endif +ifeq ($(BUILD_HFLOAT16),1) + OMP_NUM_THREADS=2 ./test_shgemv > SHBLAT2.SUMM + @$(GREP) -q FATAL SHBLAT2.SUMM && cat SHBLAT2.SUMM || exit 0 +endif ifeq ($(BUILD_SINGLE),1) OMP_NUM_THREADS=2 ./sblat2 < ./sblat2.dat @$(GREP) -q FATAL SBLAT2.SUMM && cat SBLAT2.SUMM || exit 0 @@ -195,6 +206,10 @@ ifeq ($(BUILD_BFLOAT16),1) OMP_NUM_THREADS=2 ./test_sbgemv > SBBLAT2.SUMM @$(GREP) -q FATAL SBBLAT2.SUMM && cat SBBLAT2.SUMM || exit 0 endif +ifeq ($(BUILD_HFLOAT16),1) + 
OMP_NUM_THREADS=2 ./test_shgemv > SHBLAT2.SUMM + @$(GREP) -q FATAL SHBLAT2.SUMM && cat SHBLAT2.SUMM || exit 0 +endif ifeq ($(BUILD_SINGLE),1) OPENBLAS_NUM_THREADS=2 ./sblat2 < ./sblat2.dat @$(GREP) -q FATAL SBLAT2.SUMM && cat SBLAT2.SUMM || exit 0 @@ -438,6 +453,12 @@ test_sbgemv : compare_sgemv_sbgemv.c ../$(LIBNAME) $(CC) $(CLDFLAGS) -DIBFLOAT16 -o test_sbgemv compare_sgemv_sbgemv.c ../$(LIBNAME) $(EXTRALIB) $(CEXTRALIB) endif +ifeq ($(BUILD_HFLOAT16),1) +test_shgemv : compare_sgemv_shgemv.c ../$(LIBNAME) + $(CC) $(CLDFLAGS) -o test_shgemv compare_sgemv_shgemv.c ../$(LIBNAME) $(EXTRALIB) $(CEXTRALIB) +endif + + ifeq ($(BUILD_COMPLEX),1) cblat3_3m : cblat3_3m.$(SUFFIX) ../$(LIBNAME) $(FC) $(FLDFLAGS) -o cblat3_3m cblat3_3m.$(SUFFIX) ../$(LIBNAME) $(EXTRALIB) $(CEXTRALIB) @@ -454,7 +475,7 @@ clean: @rm -f *.$(SUFFIX) *.$(PSUFFIX) gmon.$(SUFFIX)ut *.SUMM *.cxml *.exe *.pdb *.dwf \ sblat1 dblat1 cblat1 zblat1 \ sblat2 dblat2 cblat2 zblat2 \ - test_bgemm test_bgemv test_sbgemm test_sbgemv sblat3 dblat3 cblat3 zblat3 \ + test_bgemm test_bgemv test_sbgemm test_sbgemv test_shgemv sblat3 dblat3 cblat3 zblat3 \ sblat1p dblat1p cblat1p zblat1p \ sblat2p dblat2p cblat2p zblat2p \ sblat3p dblat3p cblat3p zblat3p \ diff --git a/test/compare_sgemv_shgemv.c b/test/compare_sgemv_shgemv.c new file mode 100644 index 0000000000..9e92218acb --- /dev/null +++ b/test/compare_sgemv_shgemv.c @@ -0,0 +1,130 @@ +/*************************************************************************** +Copyright (c) 2020,2025 The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. 
Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*****************************************************************************/ +#include <stdio.h> +#include <stdint.h> +#include "../common.h" + +#include "test_helpers.h" + +#define SGEMV BLASFUNC(sgemv) +#define SHGEMV BLASFUNC(shgemv) +#define SHGEMV_LARGEST 256 +/* Cross-checks SHGEMV against SGEMV and a float reference for square sizes 1..SHGEMV_LARGEST, both transposes, strides 1 and 2, alpha/beta in {0,1,2}; exits with 1 on any mismatch or failed allocation. */ +int +main (int argc, char *argv[]) +{ + blasint k; + int i, j, l; + blasint x, y; + int ret = 0; + int loop = SHGEMV_LARGEST; + char transA = 'N'; + float alpha = 1.0, beta = 0.0; + + for (beta = 0; beta < 3; beta += 1) { + for (alpha = 0; alpha < 3; alpha += 1) { + for (l = 0; l < 2; l++) { // l = 1 to test inc_x & inc_y not equal to one. + for (x = 1; x <= loop; x++) + { + k = l + 1; /* stride handed to both routines as inc_x and inc_y; x is never 0 here */
+ float *A = (float *)malloc_safe(x * x * sizeof(FLOAT)); + float *B = (float *)malloc_safe(x * sizeof(FLOAT) << l); + float *C = (float *)malloc_safe(x * sizeof(FLOAT) << l); + hfloat16 *AA = (hfloat16 *)malloc_safe(x * x * sizeof(hfloat16)); + hfloat16 *BB = (hfloat16 *)malloc_safe(x * sizeof(hfloat16) << l); + float *CC = (float *)malloc_safe(x * sizeof(FLOAT) << l); + float *DD = (float *)malloc_safe(x * sizeof(FLOAT)); + if ((A == NULL) || (B == NULL) || (C == NULL) || (AA == NULL) || (BB == NULL) || + (DD == NULL) || (CC == NULL)) + return 1; + + for (j = 0; j < x; j++) + { + for (i = 0; i < x; i++) + { + A[j * x + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; + AA[j * x + i] = (_Float16)A[j * x + i]; + } + B[j << l] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; + BB[j << l]= (_Float16)B[j << l]; + + CC[j << l] = C[j << l] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; + } + + for (y = 0; y < 2; y++) + { + if (y == 0) { + transA = 'N'; + } else { + transA = 'T'; + } + + memset(CC, 0, x * sizeof(FLOAT) << l); + memset(DD, 0, x * sizeof(FLOAT)); + memset(C, 0, x * sizeof(FLOAT) << l); /* NOTE(review): y is zeroed before both calls, so the beta term contributes nothing to any compared value */ + + SGEMV (&transA, &x, &x, &alpha, A, &x, B, &k, &beta, C, &k); + SHGEMV (&transA, &x, &x, &alpha, (hfloat16*) AA, &x, (hfloat16*) BB, &k, &beta, CC, &k); + + for (int i = 0; i < x; i ++) DD[i] *= beta; + + for (j = 0; j < x; j++) + for (i = 0; i < x; i++) + if (transA == 'N') { + DD[i] += alpha * (float)(AA[j * x + i]) * (float)(BB[j << l]); + } else if (transA == 'T') { + DD[j] += alpha * (float)(AA[j * x + i]) * (float)(BB[i << l]); + } + + for (j = 0; j < x; j++) { + if (!is_close(CC[j << l], C[j << l], 0.01, 0.001)) { /* half-precision inputs vs. full float SGEMV: looser tolerance */ + ret++; + } + if (!is_close(CC[j << l], DD[j], 0.001, 0.0001)) { /* vs. float reference built from the same rounded inputs: tighter tolerance */ + ret++; + } + } + } + free(A); + free(B); + free(C); + free(AA); + free(BB); + free(DD); + free(CC); + } // x + } // l + } // alpha + } // beta + + if (ret != 0) { + fprintf (stderr, "SHGEMV FAILURES: %d\n", ret); + return 1; + } + + return ret; +}