diff --git a/cblas.h b/cblas.h index 7503e43f7a..8395f1b8b2 100644 --- a/cblas.h +++ b/cblas.h @@ -456,6 +456,14 @@ void cblas_cgemm_batch(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enu void cblas_zgemm_batch(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransA_array, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransB_array, OPENBLAS_CONST blasint * M_array, OPENBLAS_CONST blasint * N_array, OPENBLAS_CONST blasint * K_array, OPENBLAS_CONST void * alpha_array, OPENBLAS_CONST void ** A_array, OPENBLAS_CONST blasint * lda_array, OPENBLAS_CONST void ** B_array, OPENBLAS_CONST blasint * ldb_array, OPENBLAS_CONST void * beta_array, void ** C_array, OPENBLAS_CONST blasint * ldc_array, OPENBLAS_CONST blasint group_count, OPENBLAS_CONST blasint * group_size); +void cblas_sgemm_batch_strided(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransA, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransB, OPENBLAS_CONST blasint M, OPENBLAS_CONST blasint N, OPENBLAS_CONST blasint K, OPENBLAS_CONST float alpha, OPENBLAS_CONST float * A, OPENBLAS_CONST blasint lda, OPENBLAS_CONST blasint stridea, OPENBLAS_CONST float * B, OPENBLAS_CONST blasint ldb, OPENBLAS_CONST blasint strideb, OPENBLAS_CONST float beta, float * C, OPENBLAS_CONST blasint ldc, OPENBLAS_CONST blasint stridec, OPENBLAS_CONST blasint group_size); + +void cblas_dgemm_batch_strided(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransA, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransB, OPENBLAS_CONST blasint M, OPENBLAS_CONST blasint N, OPENBLAS_CONST blasint K, OPENBLAS_CONST double alpha, OPENBLAS_CONST double * A, OPENBLAS_CONST blasint lda, OPENBLAS_CONST blasint stridea, OPENBLAS_CONST double * B, OPENBLAS_CONST blasint ldb, OPENBLAS_CONST blasint strideb, OPENBLAS_CONST double beta, double * C, OPENBLAS_CONST blasint ldc, OPENBLAS_CONST blasint stridec, OPENBLAS_CONST blasint group_size); + +void cblas_cgemm_batch_strided(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransA, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransB, OPENBLAS_CONST blasint M, OPENBLAS_CONST blasint N, OPENBLAS_CONST blasint K, OPENBLAS_CONST void * alpha, OPENBLAS_CONST void * A, OPENBLAS_CONST blasint lda, OPENBLAS_CONST blasint stridea, OPENBLAS_CONST void * B, OPENBLAS_CONST blasint ldb, OPENBLAS_CONST blasint strideb, OPENBLAS_CONST void * beta, void * C, OPENBLAS_CONST blasint ldc, OPENBLAS_CONST blasint stridec, OPENBLAS_CONST blasint group_size); + +void cblas_zgemm_batch_strided(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransA, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransB, OPENBLAS_CONST blasint M, OPENBLAS_CONST blasint N, OPENBLAS_CONST blasint K, OPENBLAS_CONST void * alpha, OPENBLAS_CONST void * A, OPENBLAS_CONST blasint lda, OPENBLAS_CONST blasint stridea, OPENBLAS_CONST void * B, OPENBLAS_CONST blasint ldb, OPENBLAS_CONST blasint strideb, OPENBLAS_CONST void * beta, void * C, OPENBLAS_CONST blasint ldc, OPENBLAS_CONST blasint stridec, OPENBLAS_CONST blasint group_size); + /*** BFLOAT16 and INT8 extensions ***/ /* convert float array to BFLOAT16 array by rounding */ void cblas_sbstobf16(OPENBLAS_CONST blasint n, OPENBLAS_CONST float *in, OPENBLAS_CONST blasint incin, bfloat16 *out, OPENBLAS_CONST blasint incout); @@ -477,6 +485,7 @@ void cblas_sbgemm(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum C void cblas_sbgemm_batch(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransA_array, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransB_array, OPENBLAS_CONST blasint * M_array, OPENBLAS_CONST blasint * N_array, OPENBLAS_CONST blasint * K_array, OPENBLAS_CONST float * alpha_array, OPENBLAS_CONST bfloat16 ** A_array, OPENBLAS_CONST blasint * lda_array, OPENBLAS_CONST bfloat16 ** B_array, OPENBLAS_CONST blasint * ldb_array, OPENBLAS_CONST float * beta_array, float ** C_array, OPENBLAS_CONST blasint * ldc_array, OPENBLAS_CONST blasint group_count, OPENBLAS_CONST blasint * group_size); +void cblas_sbgemm_batch_strided(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransA, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransB, OPENBLAS_CONST blasint M, OPENBLAS_CONST blasint N, OPENBLAS_CONST blasint K, OPENBLAS_CONST float alpha, OPENBLAS_CONST bfloat16 * A, OPENBLAS_CONST blasint lda, OPENBLAS_CONST blasint stridea, OPENBLAS_CONST bfloat16 * B, OPENBLAS_CONST blasint ldb, OPENBLAS_CONST blasint strideb, OPENBLAS_CONST float beta, float * C, OPENBLAS_CONST blasint ldc, OPENBLAS_CONST blasint stridec, OPENBLAS_CONST blasint group_size); /*** FLOAT16 extensions ***/ void cblas_shgemm(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransA, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransB, OPENBLAS_CONST blasint M, OPENBLAS_CONST blasint N, OPENBLAS_CONST blasint K, OPENBLAS_CONST float alpha, OPENBLAS_CONST hfloat16 *A, OPENBLAS_CONST blasint lda, OPENBLAS_CONST hfloat16 *B, OPENBLAS_CONST blasint ldb, OPENBLAS_CONST float beta, float *C, OPENBLAS_CONST blasint ldc); diff --git a/interface/CMakeLists.txt b/interface/CMakeLists.txt index 25b127d9b2..e8e0a88969 100644 --- a/interface/CMakeLists.txt +++ b/interface/CMakeLists.txt @@ -125,6 +125,7 @@ foreach (CBLAS_FLAG ${CBLAS_FLAGS}) if (BUILD_SINGLE OR BUILD_DOUBLE) GenerateNamedObjects("sdsdot.c" "" "sdsdot" ${CBLAS_FLAG} "" "" true "SINGLE") GenerateNamedObjects("gemm_batch.c" "" "gemm_batch" ${CBLAS_FLAG} "" "" false) + GenerateNamedObjects("gemm_batch_strided.c" "" "gemm_batch_strided" ${CBLAS_FLAG} "" "" false) endif () if (BUILD_DOUBLE) GenerateNamedObjects("dsdot.c" "" "dsdot" ${CBLAS_FLAG} "" "" true "SINGLE") @@ -161,6 +162,7 @@ if (BUILD_BFLOAT16) GenerateNamedObjects("bf16to.c" "SINGLE_PREC" "sbf16tos" ${CBLAS_FLAG} "" "" true "BFLOAT16") GenerateNamedObjects("bf16to.c" "DOUBLE_PREC" "dbf16tod" ${CBLAS_FLAG} "" "" true "BFLOAT16") GenerateNamedObjects("gemm_batch.c" "" "sbgemm_batch" ${CBLAS_FLAG} "" "" true "BFLOAT16") + GenerateNamedObjects("gemm_batch_strided.c" "" "sbgemm_batch_strided" ${CBLAS_FLAG} "" "" true "BFLOAT16") endif () if (BUILD_HFLOAT16) GenerateNamedObjects("gemm.c" "" "shgemm" ${CBLAS_FLAG} "" "" true "HFLOAT16") @@ -194,6 +196,7 @@ foreach (float_type ${FLOAT_TYPES}) GenerateNamedObjects("asum.c" "" "scasum" ${CBLAS_FLAG} "" "" true "COMPLEX") GenerateNamedObjects("sum.c" "" "scsum" ${CBLAS_FLAG} "" "" true "COMPLEX") GenerateNamedObjects("gemm_batch.c" "" "cgemm_batch" ${CBLAS_FLAG} "" "" true "COMPLEX") + GenerateNamedObjects("gemm_batch_strided.c" "" "cgemm_batch_strided" ${CBLAS_FLAG} "" "" true "COMPLEX") endif () if (${float_type} STREQUAL "ZCOMPLEX") GenerateNamedObjects("zscal.c" "SSCAL" "dscal" ${CBLAS_FLAG} "" "" false "ZCOMPLEX") @@ -204,6 +207,7 @@ foreach (float_type ${FLOAT_TYPES}) GenerateNamedObjects("asum.c" "" "dzasum" ${CBLAS_FLAG} "" "" true "ZCOMPLEX") GenerateNamedObjects("sum.c" "" "dzsum" ${CBLAS_FLAG} "" "" true "ZCOMPLEX") GenerateNamedObjects("gemm_batch.c" "" "zgemm_batch" ${CBLAS_FLAG} "" "" true "ZCOMPLEX") + GenerateNamedObjects("gemm_batch_strided.c" "" "zgemm_batch_strided" ${CBLAS_FLAG} "" "" true "ZCOMPLEX") endif () endforeach () @@ -255,6 +259,7 @@ if ( BUILD_COMPLEX AND NOT BUILD_SINGLE) GenerateNamedObjects("gemv.c" "" "gemv" 0 "" "" false "SINGLE") GenerateNamedObjects("gemm.c" "" "gemm" 0 "" "" false "SINGLE") GenerateNamedObjects("gemm_batch.c" "" "gemm_batch" 0 "" "" false "SINGLE") + GenerateNamedObjects("gemm_batch_strided.c" "" "gemm_batch_strided" 0 "" "" false "SINGLE") GenerateNamedObjects("asum.c" "" "asum" 0 "" "" false "SINGLE") GenerateNamedObjects("swap.c" "" "swap" 0 "" "" false "SINGLE") GenerateNamedObjects("axpy.c" "" "axpy" 0 "" "" false "SINGLE") @@ -269,6 +274,7 @@ if ( BUILD_COMPLEX16 AND NOT BUILD_DOUBLE) GenerateNamedObjects("gemv.c" "" "gemv" 0 "" "" false "DOUBLE") GenerateNamedObjects("gemm.c" "" "gemm" 0 "" "" false "DOUBLE") GenerateNamedObjects("gemm_batch.c" "" "gemm_batch" 0 "" "" false "DOUBLE") + GenerateNamedObjects("gemm_batch_strided.c" "" "gemm_batch_strided" 0 "" "" false "DOUBLE") GenerateNamedObjects("asum.c" "" "asum" 0 "" "" false "DOUBLE") GenerateNamedObjects("swap.c" "" "swap" 0 "" "" false "DOUBLE") GenerateNamedObjects("axpy.c" "" "axpy" 0 "" "" false "DOUBLE") diff --git a/interface/Makefile b/interface/Makefile index 9212abe694..abc6d053de 100644 --- a/interface/Makefile +++ b/interface/Makefile @@ -73,7 +73,7 @@ SBLAS3OBJS = \ strsm.$(SUFFIX) ssyrk.$(SUFFIX) ssyr2k.$(SUFFIX) \ somatcopy.$(SUFFIX) simatcopy.$(SUFFIX)\ sgeadd.$(SUFFIX) sgemmt.$(SUFFIX) sgemmtr.$(SUFFIX) \ - sgemm_batch.$(SUFFIX) + sgemm_batch.$(SUFFIX) sgemm_batch_strided.$(SUFFIX) ifeq ($(BUILD_BFLOAT16),1) BBLAS3OBJS = bgemm.$(SUFFIX) @@ -81,7 +81,7 @@ BBLAS2OBJS = bgemv.$(SUFFIX) BBLAS1OBJS = bscal.$(SUFFIX) SBBLAS1OBJS = sbdot.$(SUFFIX) SBBLAS2OBJS = sbgemv.$(SUFFIX) -SBBLAS3OBJS = sbgemm.$(SUFFIX) sbgemmt.$(SUFFIX) sbgemmtr.$(SUFFIX) sbgemm_batch.$(SUFFIX) +SBBLAS3OBJS = sbgemm.$(SUFFIX) sbgemmt.$(SUFFIX) sbgemmtr.$(SUFFIX) sbgemm_batch.$(SUFFIX) sbgemm_batch_strided.$(SUFFIX) SBEXTOBJS = sbstobf16.$(SUFFIX) sbdtobf16.$(SUFFIX) sbf16tos.$(SUFFIX) dbf16tod.$(SUFFIX) endif @@ -113,7 +113,7 @@ DBLAS3OBJS = \ dtrsm.$(SUFFIX) dsyrk.$(SUFFIX) dsyr2k.$(SUFFIX) \ domatcopy.$(SUFFIX) dimatcopy.$(SUFFIX)\ dgeadd.$(SUFFIX) dgemmt.$(SUFFIX) dgemmtr.$(SUFFIX) \ - dgemm_batch.$(SUFFIX) + dgemm_batch.$(SUFFIX) dgemm_batch_strided.$(SUFFIX) CBLAS1OBJS = \ caxpy.$(SUFFIX) caxpyc.$(SUFFIX) cswap.$(SUFFIX) \ @@ -143,7 +143,7 @@ CBLAS3OBJS = \ chemm.$(SUFFIX) cherk.$(SUFFIX) cher2k.$(SUFFIX) \ comatcopy.$(SUFFIX) cimatcopy.$(SUFFIX)\ cgeadd.$(SUFFIX) cgemmt.$(SUFFIX) cgemmtr.$(SUFFIX) \ - cgemm_batch.$(SUFFIX) + cgemm_batch.$(SUFFIX) cgemm_batch_strided.$(SUFFIX) ZBLAS1OBJS = \ zaxpy.$(SUFFIX) zaxpyc.$(SUFFIX) zswap.$(SUFFIX) \ @@ -173,7 +173,7 @@ ZBLAS3OBJS = \ zhemm.$(SUFFIX) zherk.$(SUFFIX) zher2k.$(SUFFIX) \ zomatcopy.$(SUFFIX) zimatcopy.$(SUFFIX)\ zgeadd.$(SUFFIX) zgemmt.$(SUFFIX) zgemmtr.$(SUFFIX) \ - zgemm_batch.$(SUFFIX) + zgemm_batch.$(SUFFIX) zgemm_batch_strided.$(SUFFIX) ifeq ($(SUPPORT_GEMM3M), 1) @@ -321,7 +321,7 @@ CSBLAS2OBJS = \ CSBLAS3OBJS = \ cblas_sgemm.$(SUFFIX) cblas_ssymm.$(SUFFIX) cblas_strmm.$(SUFFIX) cblas_strsm.$(SUFFIX) \ cblas_ssyrk.$(SUFFIX) cblas_ssyr2k.$(SUFFIX) cblas_somatcopy.$(SUFFIX) cblas_simatcopy.$(SUFFIX)\ - cblas_sgeadd.$(SUFFIX) cblas_sgemmt.$(SUFFIX) cblas_sgemmtr.$(SUFFIX) cblas_sgemm_batch.$(SUFFIX) + cblas_sgeadd.$(SUFFIX) cblas_sgemmt.$(SUFFIX) cblas_sgemmtr.$(SUFFIX) cblas_sgemm_batch.$(SUFFIX) cblas_sgemm_batch_strided.$(SUFFIX) ifeq ($(BUILD_BFLOAT16),1) CBBLAS3OBJS = cblas_bgemm.$(SUFFIX) @@ -329,7 +329,7 @@ CBBLAS2OBJS = cblas_bgemv.$(SUFFIX) CBBLAS1OBJS = cblas_bscal.$(SUFFIX) CSBBLAS1OBJS = cblas_sbdot.$(SUFFIX) CSBBLAS2OBJS = cblas_sbgemv.$(SUFFIX) -CSBBLAS3OBJS = cblas_sbgemm.$(SUFFIX) cblas_sbgemmt.$(SUFFIX) cblas_sbgemmtr.$(SUFFIX) cblas_sbgemm_batch.$(SUFFIX) +CSBBLAS3OBJS = cblas_sbgemm.$(SUFFIX) cblas_sbgemmt.$(SUFFIX) cblas_sbgemmtr.$(SUFFIX) cblas_sbgemm_batch.$(SUFFIX) cblas_sbgemm_batch_strided.$(SUFFIX) CSBEXTOBJS = cblas_sbstobf16.$(SUFFIX) cblas_sbdtobf16.$(SUFFIX) cblas_sbf16tos.$(SUFFIX) cblas_dbf16tod.$(SUFFIX) ifeq ($(ONLY_CBLAS),1) CSBEXTOBJS += sbstobf16.$(SUFFIX) sbdtobf16.$(SUFFIX) sbf16tos.$(SUFFIX) dbf16tod.$(SUFFIX) @@ -357,7 +357,7 @@ CDBLAS2OBJS = \ CDBLAS3OBJS += \ cblas_dgemm.$(SUFFIX) cblas_dsymm.$(SUFFIX) cblas_dtrmm.$(SUFFIX) cblas_dtrsm.$(SUFFIX) \ cblas_dsyrk.$(SUFFIX) cblas_dsyr2k.$(SUFFIX) cblas_domatcopy.$(SUFFIX) cblas_dimatcopy.$(SUFFIX) \ - cblas_dgeadd.$(SUFFIX) cblas_dgemmt.$(SUFFIX) cblas_dgemmtr.$(SUFFIX) cblas_dgemm_batch.$(SUFFIX) + cblas_dgeadd.$(SUFFIX) cblas_dgemmt.$(SUFFIX) cblas_dgemmtr.$(SUFFIX) cblas_dgemm_batch.$(SUFFIX) cblas_dgemm_batch_strided.$(SUFFIX) CCBLAS1OBJS = \ cblas_icamax.$(SUFFIX) cblas_icamin.$(SUFFIX) cblas_scasum.$(SUFFIX) cblas_caxpy.$(SUFFIX) \ @@ -382,7 +382,7 @@ CCBLAS3OBJS = \ cblas_csyrk.$(SUFFIX) cblas_csyr2k.$(SUFFIX) \ cblas_chemm.$(SUFFIX) cblas_cherk.$(SUFFIX) cblas_cher2k.$(SUFFIX) \ cblas_comatcopy.$(SUFFIX) cblas_cimatcopy.$(SUFFIX)\ - cblas_cgeadd.$(SUFFIX) cblas_cgemmt.$(SUFFIX) cblas_cgemmtr.$(SUFFIX) cblas_cgemm_batch.$(SUFFIX) + cblas_cgeadd.$(SUFFIX) cblas_cgemmt.$(SUFFIX) cblas_cgemmtr.$(SUFFIX) cblas_cgemm_batch.$(SUFFIX) cblas_cgemm_batch_strided.$(SUFFIX) CXERBLAOBJ = \ cblas_xerbla.$(SUFFIX) @@ -413,7 +413,7 @@ CZBLAS3OBJS = \ cblas_zsyrk.$(SUFFIX) cblas_zsyr2k.$(SUFFIX) \ cblas_zhemm.$(SUFFIX) cblas_zherk.$(SUFFIX) cblas_zher2k.$(SUFFIX)\ cblas_zomatcopy.$(SUFFIX) cblas_zimatcopy.$(SUFFIX) \ - cblas_zgeadd.$(SUFFIX) cblas_zgemmt.$(SUFFIX) cblas_zgemmtr.$(SUFFIX) cblas_zgemm_batch.$(SUFFIX) + cblas_zgeadd.$(SUFFIX) cblas_zgemmt.$(SUFFIX) cblas_zgemmtr.$(SUFFIX) cblas_zgemm_batch.$(SUFFIX) cblas_zgemm_batch_strided.$(SUFFIX) ifeq ($(SUPPORT_GEMM3M), 1) @@ -2544,6 +2544,21 @@ cblas_cgemm_batch.$(SUFFIX) cblas_cgemm_batch.$(PSUFFIX) : gemm_batch.c ../param cblas_zgemm_batch.$(SUFFIX) cblas_zgemm_batch.$(PSUFFIX) : gemm_batch.c ../param.h $(CC) -c $(CFLAGS) -DCBLAS $< -o $(@F) +cblas_sbgemm_batch_strided.$(SUFFIX) cblas_sbgemm_batch_strided.$(PSUFFIX) : gemm_batch_strided.c ../param.h + $(CC) -c $(CFLAGS) -DCBLAS $< -o $(@F) + +cblas_sgemm_batch_strided.$(SUFFIX) cblas_sgemm_batch_strided.$(PSUFFIX) : gemm_batch_strided.c ../param.h + $(CC) -c $(CFLAGS) -DCBLAS $< -o $(@F) + +cblas_dgemm_batch_strided.$(SUFFIX) cblas_dgemm_batch_strided.$(PSUFFIX) : gemm_batch_strided.c ../param.h + $(CC) -c $(CFLAGS) -DCBLAS $< -o $(@F) + +cblas_cgemm_batch_strided.$(SUFFIX) cblas_cgemm_batch_strided.$(PSUFFIX) : gemm_batch_strided.c ../param.h + $(CC) -c $(CFLAGS) -DCBLAS $< -o $(@F) + +cblas_zgemm_batch_strided.$(SUFFIX) cblas_zgemm_batch_strided.$(PSUFFIX) : gemm_batch_strided.c ../param.h + $(CC) -c $(CFLAGS) -DCBLAS $< -o $(@F) + sbgemm_batch.$(SUFFIX) sbgemm_batch.$(PSUFFIX) : gemm_batch.c ../param.h $(CC) -c $(CFLAGS) -UCBLAS $< -o $(@F) @@ -2559,3 +2574,17 @@ cgemm_batch.$(SUFFIX) cgemm_batch.$(PSUFFIX) : gemm_batch.c ../param.h zgemm_batch.$(SUFFIX) zgemm_batch.$(PSUFFIX) : gemm_batch.c ../param.h $(CC) -c $(CFLAGS) -UCBLAS $< -o $(@F) +sbgemm_batch_strided.$(SUFFIX) sbgemm_batch_strided.$(PSUFFIX) : gemm_batch_strided.c ../param.h + $(CC) -c $(CFLAGS) -UCBLAS $< -o $(@F) + +sgemm_batch_strided.$(SUFFIX) sgemm_batch_strided.$(PSUFFIX) : gemm_batch_strided.c ../param.h + $(CC) -c $(CFLAGS) -UCBLAS $< -o $(@F) + +dgemm_batch_strided.$(SUFFIX) dgemm_batch_strided.$(PSUFFIX) : gemm_batch_strided.c ../param.h + $(CC) -c $(CFLAGS) -UCBLAS $< -o $(@F) + +cgemm_batch_strided.$(SUFFIX) cgemm_batch_strided.$(PSUFFIX) : gemm_batch_strided.c ../param.h + $(CC) -c $(CFLAGS) -UCBLAS $< -o $(@F) + +zgemm_batch_strided.$(SUFFIX) zgemm_batch_strided.$(PSUFFIX) : gemm_batch_strided.c ../param.h + $(CC) -c $(CFLAGS) -UCBLAS $< -o $(@F) diff --git a/interface/gemm_batch_strided.c b/interface/gemm_batch_strided.c new file mode 100644 index 0000000000..8435b65502 --- /dev/null +++ b/interface/gemm_batch_strided.c @@ -0,0 +1,425 @@ +/***************************************************************************** +Copyright (c) 2025, The OpenBLAS Project +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + 1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + 2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in + the documentation and/or other materials provided with the + distribution. + 3. Neither the name of the OpenBLAS project nor the names of + its contributors may be used to endorse or promote products + derived from this software without specific prior written + permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +**********************************************************************************/ + +#include +#include +#include "common.h" + +void openblas_warning(int verbose, const char * msg); + +#ifndef COMPLEX +#ifdef XDOUBLE +#define ERROR_NAME "QGEMM_BATCH_STRIDED " +#elif defined(DOUBLE) +#define ERROR_NAME "DGEMM_BATCH_STRIDED " +#define GEMM_BATCH_THREAD dgemm_batch_thread +#else +#define ERROR_NAME "SGEMM_BATCH_STRIDED " +#define GEMM_BATCH_THREAD sgemm_batch_thread +#endif +#else +#ifdef XDOUBLE +#define ERROR_NAME "XGEMM_BATCH_STRIDED " +#elif defined(DOUBLE) +#define ERROR_NAME "ZGEMM_BATCH_STRIDED " +#define GEMM_BATCH_THREAD zgemm_batch_thread +#else +#define ERROR_NAME "CGEMM_BATCH_STRIDED " +#define GEMM_BATCH_THREAD cgemm_batch_thread +#endif +#endif +static int (*gemm[])(blas_arg_t *, BLASLONG *, BLASLONG *, IFLOAT *, IFLOAT *, BLASLONG) = { + GEMM_NN, GEMM_TN, GEMM_RN, GEMM_CN, + GEMM_NT, GEMM_TT, GEMM_RT, GEMM_CT, + GEMM_NR, GEMM_TR, GEMM_RR, GEMM_CR, + GEMM_NC, GEMM_TC, GEMM_RC, GEMM_CC, +}; + +#if defined(SMALL_MATRIX_OPT) && !defined(GEMM3M) && !defined(XDOUBLE) +#define USE_SMALL_MATRIX_OPT 1 +#else +#define USE_SMALL_MATRIX_OPT 0 +#endif + +#if USE_SMALL_MATRIX_OPT +#ifndef DYNAMIC_ARCH +#define SMALL_KERNEL_ADDR(table, idx) ((void *)(table[idx])) +#else +#define SMALL_KERNEL_ADDR(table, idx) ((void *)(*(uintptr_t *)((char *)gotoblas + (size_t)(table[idx])))) +#endif + + +#ifndef COMPLEX +static size_t gemm_small_kernel[] = { + GEMM_SMALL_KERNEL_NN, GEMM_SMALL_KERNEL_TN, 0, 0, + GEMM_SMALL_KERNEL_NT, GEMM_SMALL_KERNEL_TT, 0, 0, +}; + + +static size_t gemm_small_kernel_b0[] = { + GEMM_SMALL_KERNEL_B0_NN, GEMM_SMALL_KERNEL_B0_TN, 0, 0, + GEMM_SMALL_KERNEL_B0_NT, GEMM_SMALL_KERNEL_B0_TT, 0, 0, +}; + +#define GEMM_SMALL_KERNEL_B0(idx) (int (*)(BLASLONG, BLASLONG, BLASLONG, IFLOAT *, BLASLONG, FLOAT, IFLOAT *, BLASLONG, FLOAT *, BLASLONG)) SMALL_KERNEL_ADDR(gemm_small_kernel_b0, (idx)) +#define GEMM_SMALL_KERNEL(idx) (int (*)(BLASLONG, BLASLONG, BLASLONG, IFLOAT *, BLASLONG, FLOAT, IFLOAT *, BLASLONG, FLOAT, FLOAT *, BLASLONG)) SMALL_KERNEL_ADDR(gemm_small_kernel, (idx)) +#else + +static size_t zgemm_small_kernel[] = { + GEMM_SMALL_KERNEL_NN, GEMM_SMALL_KERNEL_TN, GEMM_SMALL_KERNEL_RN, GEMM_SMALL_KERNEL_CN, + GEMM_SMALL_KERNEL_NT, GEMM_SMALL_KERNEL_TT, GEMM_SMALL_KERNEL_RT, GEMM_SMALL_KERNEL_CT, + GEMM_SMALL_KERNEL_NR, GEMM_SMALL_KERNEL_TR, GEMM_SMALL_KERNEL_RR, GEMM_SMALL_KERNEL_CR, + GEMM_SMALL_KERNEL_NC, GEMM_SMALL_KERNEL_TC, GEMM_SMALL_KERNEL_RC, GEMM_SMALL_KERNEL_CC, +}; + +static size_t zgemm_small_kernel_b0[] = { + GEMM_SMALL_KERNEL_B0_NN, GEMM_SMALL_KERNEL_B0_TN, GEMM_SMALL_KERNEL_B0_RN, GEMM_SMALL_KERNEL_B0_CN, + GEMM_SMALL_KERNEL_B0_NT, GEMM_SMALL_KERNEL_B0_TT, GEMM_SMALL_KERNEL_B0_RT, GEMM_SMALL_KERNEL_B0_CT, + GEMM_SMALL_KERNEL_B0_NR, GEMM_SMALL_KERNEL_B0_TR, GEMM_SMALL_KERNEL_B0_RR, GEMM_SMALL_KERNEL_B0_CR, + GEMM_SMALL_KERNEL_B0_NC, GEMM_SMALL_KERNEL_B0_TC, GEMM_SMALL_KERNEL_B0_RC, GEMM_SMALL_KERNEL_B0_CC, +}; + +#define ZGEMM_SMALL_KERNEL(idx) (int (*)(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT , FLOAT, FLOAT *, BLASLONG, FLOAT , FLOAT, FLOAT *, BLASLONG)) SMALL_KERNEL_ADDR(zgemm_small_kernel, (idx)) +#define ZGEMM_SMALL_KERNEL_B0(idx) (int (*)(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT , FLOAT, FLOAT *, BLASLONG, FLOAT *, BLASLONG)) SMALL_KERNEL_ADDR(zgemm_small_kernel_b0, (idx)) +#endif +#endif + +#ifndef CBLAS +void NAME(char *transa, char *transb, + blasint * M, blasint * N, blasint * K, + FLOAT * Alpha, + IFLOAT * a, blasint * Lda, + blasint * stride_a, + IFLOAT *b, blasint * Ldb, + blasint * stride_b, + FLOAT * Beta, + FLOAT * c, blasint * Ldc, blasint * stride_c, blasint * matcount) { + + char ta = *transa; + char tb = *transb; + blasint count = *matcount; + blasint stridea= *stride_a; + blasint strideb= *stride_b; + blasint stridec= *stride_c; + blasint m=*M; + blasint n=*N; + blasint k=*K; + blasint lda=*Lda; + blasint ldb=*Ldb; + blasint ldc=*Ldc; +#if !defined(COMPLEX) + FLOAT alpha=*Alpha; + FLOAT beta=*Beta; +#else + FLOAT *alpha=Alpha; + FLOAT *beta=Beta; +#endif +#else + +void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE transa, enum CBLAS_TRANSPOSE transb, + blasint m, blasint n, blasint k, +#ifndef COMPLEX + FLOAT alpha, + IFLOAT * a, blasint lda, blasint stridea, + IFLOAT * b, blasint ldb, blasint strideb, + FLOAT beta, + FLOAT * c, blasint ldc, blasint stridec, blasint count) { +#else + void * valpha, + void * va, blasint lda, blasint stridea, + void * vb, blasint ldb, blasint strideb, + void * vbeta, + void * vc, blasint ldc, blasint stridec, blasint count) { + + FLOAT * alpha=(FLOAT *)valpha; + FLOAT * beta=(FLOAT *)vbeta; + FLOAT * a=(FLOAT*)va; + FLOAT * b=(FLOAT*)vb; + FLOAT * c=(FLOAT*)vc; +#endif +#endif + BLASLONG group_m, group_n, group_k; + BLASLONG group_lda, group_ldb, group_ldc; + + blas_arg_t * args_array=NULL; + + int mode=0, group_mode=0; + + blasint i=0; + + int group_transa, group_transb; + BLASLONG group_nrowa, group_nrowb; + blasint info; + + void * group_routine=NULL; +#ifdef SMALL_MATRIX_OPT + void * group_small_matrix_opt_routine=NULL; +#endif + +#if defined (SMP) || defined(SMALL_MATRIX_OPT) + double MNK; +#endif + + PRINT_DEBUG_CNAME; + + args_array=(blas_arg_t *)malloc(count * sizeof(blas_arg_t)); + + if(args_array == NULL){ + openblas_warning(0, "memory alloc failed!\n"); + return; + } + +#ifdef SMP +#ifndef COMPLEX +#ifdef XDOUBLE + mode = BLAS_XDOUBLE | BLAS_REAL; +#elif defined(DOUBLE) + mode = BLAS_DOUBLE | BLAS_REAL; +#else + mode = BLAS_SINGLE | BLAS_REAL; +#endif +#else +#ifdef XDOUBLE + mode = BLAS_XDOUBLE | BLAS_COMPLEX; +#elif defined(DOUBLE) + mode = BLAS_DOUBLE | BLAS_COMPLEX; +#else + mode = BLAS_SINGLE | BLAS_COMPLEX; +#endif +#endif +#endif + + for(i=0; i= 0) { + BLASFUNC(xerbla)(ERROR_NAME, &info, sizeof(ERROR_NAME)); + free(args_array); + return; + } + + if (group_m == 0 || group_n == 0) continue; + + group_mode=mode; + +#if defined(SMP) || defined(SMALL_MATRIX_OPT) + MNK = (double) group_m * (double) group_n * (double) group_k; +#endif + +#ifdef SMALL_MATRIX_OPT + if (MNK <= 100.0*100.0*100.0){ + group_routine=NULL; +#if !defined(COMPLEX) + if(beta == 0.0){ + group_mode=mode | BLAS_SMALL_B0_OPT; + group_small_matrix_opt_routine=(void *)(gemm_small_kernel_b0[(group_transb<<2)|group_transa]); + }else{ + group_mode=mode | BLAS_SMALL_OPT; + group_small_matrix_opt_routine=(void *)(gemm_small_kernel[(group_transb<<2)|group_transa]); + } +#else + if(beta[0] == 0.0 && beta[1] == 0.0){ + group_mode=mode | BLAS_SMALL_B0_OPT; + group_small_matrix_opt_routine=(void *)(zgemm_small_kernel_b0[(group_transb<<2)|group_transa]); + }else{ + group_mode=mode | BLAS_SMALL_OPT; + group_small_matrix_opt_routine=(void *)(zgemm_small_kernel[(group_transb<<2)|group_transa]); + } + +#endif + + }else{ +#endif + group_routine=(void*)(gemm[(group_transb<<2)|group_transa]); +#ifdef SMALL_MATRIX_OPT + } +#endif + + + args_array[i].m=group_m; + args_array[i].n=group_n; + args_array[i].k=group_k; + args_array[i].lda=group_lda; + args_array[i].ldb=group_ldb; + args_array[i].ldc=group_ldc; + args_array[i].alpha=α + args_array[i].beta=β + +#if defined(CBLAS) + if (order == CblasColMajor) { + args_array[i].a=&(a[i*stridea]); + args_array[i].b=&(b[i*strideb]); + }else if(order == CblasRowMajor){ + args_array[i].a=&(b[i*strideb]); + args_array[i].b=&(a[i*stridea]); + } +#else + args_array[i].a=&(a[i*stridea]); + args_array[i].b=&(b[i*strideb]); +#endif + + args_array[i].c= &c[i*stridec]; + + args_array[i].routine_mode=group_mode; + args_array[i].routine=group_routine; +#ifdef SMALL_MATRIX_OPT + if (!group_routine) + args_array[i].routine=group_small_matrix_opt_routine; +#endif + } + + if(count>0) { + GEMM_BATCH_THREAD(args_array,count); + } + + free(args_array); +}