Skip to content

Commit

Permalink
Implement cublasSgemm
Browse files Browse the repository at this point in the history
  • Loading branch information
mrdomino committed Dec 12, 2023
1 parent 4fb0813 commit ffb039d
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 5 deletions.
2 changes: 1 addition & 1 deletion llama.cpp/ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@

#if defined(GGML_USE_NAIVE)

// #define cublasSgemm_v2 cublasSgemm_v2_
#define cublasSgemm_v2 cublasSgemm_v2_
#define cublasGemmEx cublasGemmEx_
#define cublasGemmStridedBatchedEx cublasGemmStridedBatchEx_
#include "naive-gemm.cu"
Expand Down
42 changes: 38 additions & 4 deletions llama.cpp/naive-gemm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
#include <cublas_v2.h>

#define READ(A, trans, ld, i, j) \
__half2float(((trans) == CUBLAS_OP_N) \
? (A)[(i) + (j) * (ld)] : (A)[(j) + (i) * (ld)])
(((trans) == CUBLAS_OP_N) ? (A)[(i) + (j) * (ld)] : (A)[(j) + (i) * (ld)])
#define READ16(...) __half2float(READ(__VA_ARGS__))

static __device__ __forceinline__ void matmul(int m, int n, int k,
const half *A, int lda,
Expand All @@ -15,8 +15,8 @@ static __device__ __forceinline__ void matmul(int m, int n, int k,
float sum = 0.0;
half *cptr = C + i + j * ldc;
for (int l = 0; l < k; ++l) {
sum += READ(A, CUBLAS_OP_T, lda, i, l) *
READ(B, CUBLAS_OP_N, ldb, l, j);
sum += READ16(A, CUBLAS_OP_T, lda, i, l) *
READ16(B, CUBLAS_OP_N, ldb, l, j);
}
*cptr = __float2half(sum);
}
Expand All @@ -28,6 +28,21 @@ static __global__ void wrap_matmul(int m, int n, int k, const half *A, int lda,
matmul(m, n, k, A, lda, B, ldb, C, ldc);
}

// Naive single-precision GEMM kernel: C = A^T * B.
//
// Element (i, j) of the m x n result is the dot product, over l in [0, k), of
// A[l + i * lda] and B[l + j * ldb]; it is stored at C[i + j * ldc]. alpha and
// beta are not applied (the result overwrites C).
//
// Launch layout: any 1-D grid of 1-D blocks. Output elements are distributed
// over threads with a grid-stride loop on the x dimension, so a <<<1, 1>>>
// launch computes every element in a single thread, and wider launches split
// the work without changing the result. No shared memory is used.
//
// All element offsets are computed in size_t so that products such as
// j * ldc cannot overflow a 32-bit int for large matrices.
static __global__ void matmul32(int m, int n, int k, const float *A, int lda,
                                const float *B, int ldb, float *C, int ldc) {
  // Nothing to write for an empty (or nonsensical) output shape; this also
  // keeps the unsigned conversions below well-defined.
  if (m <= 0 || n <= 0) {
    return;
  }

  const size_t rows = (size_t)m;
  const size_t total = rows * (size_t)n;
  const size_t stride = (size_t)gridDim.x * blockDim.x;

  for (size_t idx = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
       idx < total; idx += stride) {
    // Consecutive idx values walk down a column of C, so neighbouring
    // threads write neighbouring addresses.
    const size_t i = idx % rows;
    const size_t j = idx / rows;

    const float *a = A + i * (size_t)lda;  // A[l + i * lda] for l = 0..k-1
    const float *b = B + j * (size_t)ldb;  // B[l + j * ldb] for l = 0..k-1

    float sum = 0.0f;
    for (int l = 0; l < k; ++l) {
      sum += a[l] * b[l];
    }
    C[i + j * (size_t)ldc] = sum;
  }
}

static bool check_args(cublasOperation_t transa, cublasOperation_t transb,
const void *pAlpha, cudaDataType_t Atype,
cudaDataType_t Btype, const void *pBeta,
Expand All @@ -39,6 +54,25 @@ static bool check_args(cublasOperation_t transa, cublasOperation_t transb,
__half2float(*(half *)pBeta) == 0.0f;
}

// Minimal stand-in for cuBLAS's single-precision GEMM.
//
// Only the case C = A^T * B is implemented: transa must be CUBLAS_OP_T,
// transb must be CUBLAS_OP_N, *alpha must be 1.0f and *beta must be 0.0f.
// Any other combination returns CUBLAS_STATUS_NOT_SUPPORTED.
//
// The work is enqueued as a single-thread launch of matmul32 on the stream
// associated with `handle`; the call does not wait for the kernel to finish.
//
// Returns:
//   CUBLAS_STATUS_SUCCESS          - kernel enqueued (or nothing to do).
//   CUBLAS_STATUS_INVALID_VALUE    - null alpha/beta, negative m/n/k, or a
//                                    leading dimension too small for the
//                                    operand it describes.
//   CUBLAS_STATUS_NOT_SUPPORTED    - unsupported transpose/alpha/beta.
//   CUBLAS_STATUS_EXECUTION_FAILED - the kernel launch reported an error.
//   any status from cublasGetStream if the stream lookup fails.
//
// NOTE(review): alpha and beta are dereferenced on the host, so this assumes
// the handle is in host pointer mode - confirm against callers.
cublasStatus_t cublasSgemm_v2(cublasHandle_t handle,
                              cublasOperation_t transa,
                              cublasOperation_t transb,
                              int m, int n, int k,
                              const float *alpha,
                              const float *A, int lda,
                              const float *B, int ldb,
                              const float *beta,
                              float *C, int ldc) {
  // Must be checked before the scalars are read below.
  if (alpha == nullptr || beta == nullptr) {
    return CUBLAS_STATUS_INVALID_VALUE;
  }
  if (transa != CUBLAS_OP_T || transb != CUBLAS_OP_N ||
      *alpha != 1.0f || *beta != 0.0f) {
    return CUBLAS_STATUS_NOT_SUPPORTED;
  }

  // Same argument validation cuBLAS documents for gemm. With transa == T the
  // stored A is k x m (lda >= k); with transb == N the stored B is k x n
  // (ldb >= k); C is m x n (ldc >= m).
  if (m < 0 || n < 0 || k < 0) {
    return CUBLAS_STATUS_INVALID_VALUE;
  }
  const int min_ld_ab = k > 1 ? k : 1;
  const int min_ldc = m > 1 ? m : 1;
  if (lda < min_ld_ab || ldb < min_ld_ab || ldc < min_ldc) {
    return CUBLAS_STATUS_INVALID_VALUE;
  }

  // Empty output: nothing to compute, so skip the launch entirely.
  if (m == 0 || n == 0) {
    return CUBLAS_STATUS_SUCCESS;
  }

  cudaStream_t stream;
  const cublasStatus_t stream_status = cublasGetStream(handle, &stream);
  if (stream_status != CUBLAS_STATUS_SUCCESS) {
    return stream_status;
  }

  matmul32<<<1, 1, 0, stream>>>(m, n, k, A, lda, B, ldb, C, ldc);

  // Kernel launches do not return a status; launch failures are only
  // visible through the runtime's last-error state.
  if (cudaGetLastError() != cudaSuccess) {
    return CUBLAS_STATUS_EXECUTION_FAILED;
  }
  return CUBLAS_STATUS_SUCCESS;
}

// https://docs.nvidia.com/cuda/cublas/index.html#cublasgemmex

cublasStatus_t cublasGemmEx(cublasHandle_t handle,
Expand Down

0 comments on commit ffb039d

Please sign in to comment.