[Perf] Call cublasGemm instead of cublasGemmStridedBatched (#1407)

* [Perf] Call cublasGemm instead of cublasGemmStridedBatched

* fix unittest
Aurelius84 committed May 12, 2023
1 parent 9314789 commit 06424c2
Showing 2 changed files with 44 additions and 28 deletions.
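The idea in a nutshell: when the right-hand operand of a batched matmul is one shared matrix (its batch dims are 1 and it is not transposed), the per-batch GEMMs that cublasGemmStridedBatched would launch can be folded into a single large cublasGemm by merging the batch dimension of the left operand into its row dimension. A minimal NumPy sketch of that equivalence (illustration only, not CINN or cuBLAS code):

import numpy as np

bs, M, K, N = 8, 16, 4, 16
x = np.random.rand(bs, M, K).astype(np.float32)
y = np.random.rand(K, N).astype(np.float32)

batched = x @ y                                        # bs small GEMMs (the strided-batched view)
folded = (x.reshape(bs * M, K) @ y).reshape(bs, M, N)  # one big GEMM (the fast path's view)
assert np.allclose(batched, folded)

One kernel launch over a (bs*M, K) matrix rather than bs strided sub-GEMMs is the performance win the commit title refers to.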
59 changes: 37 additions & 22 deletions cinn/runtime/cuda/cuda_util.cc
@@ -131,6 +131,9 @@ void cinn_call_cublas(void *v_args,
   cinn_pod_value_t *args = static_cast<cinn_pod_value_t *>(v_args);
   cudaStream_t custream = static_cast<cudaStream_t>(stream);
   CUBLAS_CALL(cublasSetStream(cuhandle, custream));
+  VLOG(3) << "a1 ~ a4: " << a1 << " " << a2 << " " << a3 << " " << a4;
+  VLOG(3) << "b1 ~ b4: " << b1 << " " << b2 << " " << b3 << " " << b4;
+  VLOG(3) << "trans_a: " << trans_a << ", trans_b: " << trans_b << ", trans_o: " << trans_o;

   void *A = args[0].operator cinn_buffer_t *()->memory;
   void *B = args[1].operator cinn_buffer_t *()->memory;
@@ -167,32 +170,42 @@ void cinn_call_cublas(void *v_args,
   }

   if (a1 * a2 * b1 * b2 == 1) {
+    VLOG(3) << "call cublasGemm for a1 * a2 * b1 * b2 == 1";
     CUBLAS_CALL(
         cublasGemm(cuda_dtype, cuhandle, trans_op_l, trans_op_r, m, n, k, alpha, lhs, ldl, rhs, ldr, beta, C, ldc));
   } else if (a1 * b1 == 1) {
     CHECK(a2 == b2 || a2 == 1 || b2 == 1);
-    int stride_l = trans_o ? (a2 > 1 ? a3 * a4 : 0) : (b2 > 1 ? b3 * b4 : 0);
-    int stride_r = trans_o ? (b2 > 1 ? b3 * b4 : 0) : (a2 > 1 ? a3 * a4 : 0);
-    int batch = std::max(a2, b2);
-    CUBLAS_CALL(cublasGemmStridedBatched(cuda_dtype,
-                                         cuhandle,
-                                         trans_op_l,
-                                         trans_op_r,
-                                         m,
-                                         n,
-                                         k,
-                                         alpha,
-                                         lhs,
-                                         ldl,
-                                         stride_l,
-                                         rhs,
-                                         ldr,
-                                         stride_r,
-                                         beta,
-                                         C,
-                                         ldc,
-                                         m * n,
-                                         batch));
+    if (b2 == 1 && trans_op_r == CUBLAS_OP_N) {
+      // In case of [1, bs, M, K] * [1, 1, K, N]
+      VLOG(3) << "call cublasGemm for a1 * b1 = 1, b2 = 1, trans_op_r:" << trans_op_r;
+      CUBLAS_CALL(cublasGemm(
+          cuda_dtype, cuhandle, trans_op_l, trans_op_r, m, a2 * n, k, alpha, lhs, ldl, A, ldr, beta, C, ldc));
+    } else {
+      int stride_l = trans_o ? (a2 > 1 ? a3 * a4 : 0) : (b2 > 1 ? b3 * b4 : 0);
+      int stride_r = trans_o ? (b2 > 1 ? b3 * b4 : 0) : (a2 > 1 ? a3 * a4 : 0);
+      int batch = std::max(a2, b2);
+      VLOG(3) << "call cublasGemmStridedBatched with a1*b1 = 1, stride_l = " << stride_l << ", stride_r = " << stride_r
+              << ", batch = " << batch;
+      CUBLAS_CALL(cublasGemmStridedBatched(cuda_dtype,
+                                           cuhandle,
+                                           trans_op_l,
+                                           trans_op_r,
+                                           m,
+                                           n,
+                                           k,
+                                           alpha,
+                                           lhs,
+                                           ldl,
+                                           stride_l,
+                                           rhs,
+                                           ldr,
+                                           stride_r,
+                                           beta,
+                                           C,
+                                           ldc,
+                                           m * n,
+                                           batch));
+    }
   } else {
     int l1 = trans_o ? a1 : b1, l2 = trans_o ? a2 : b2, l3 = trans_o ? a3 : b3, l4 = trans_o ? a4 : b4;
     int r1 = trans_o ? b1 : a1, r2 = trans_o ? b2 : a2, r3 = trans_o ? b3 : a3, r4 = trans_o ? b4 : a4;
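In the fallback else branch above, a stride of 0 makes every sub-GEMM reread the same matrix, which is how the operand whose batch dim is 1 gets broadcast. A small NumPy emulation of that addressing (a sketch of the strided-batched semantics on flat buffers, not the cuBLAS API):

import numpy as np

def strided_batched_matmul(a, b, stride_a, stride_b, batch, m, k, n):
    # Batch i reads a[i*stride_a:] and b[i*stride_b:], so a stride of 0
    # reuses one matrix for every batch.
    out = np.empty((batch, m, n), dtype=a.dtype)
    for i in range(batch):
        a_i = a[i * stride_a : i * stride_a + m * k].reshape(m, k)
        b_i = b[i * stride_b : i * stride_b + k * n].reshape(k, n)
        out[i] = a_i @ b_i
    return out

# [4, 2, 3] @ a single broadcast [3, 5]: stride_b = 0 replays the one B.
A = np.random.rand(4 * 2 * 3).astype(np.float32)
B = np.random.rand(3 * 5).astype(np.float32)
ref = A.reshape(4, 2, 3) @ B.reshape(3, 5)
got = strided_batched_matmul(A, B, stride_a=6, stride_b=0, batch=4, m=2, k=3, n=5)
assert np.allclose(got, ref)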
@@ -204,6 +217,8 @@ void cinn_call_cublas(void *v_args,
     // four types matmul:
     // (N, L) * (N, L) , (N, 1) * (N, 1)
     // (N, L) * (1, 1) , (1, 1) * (N, L)
+    VLOG(3) << "call cublasGemmStridedBatched for stride_l = " << stride_l << ", stride_r = " << stride_r
+            << ", batch = " << std::max(l1, r1) * std::max(l2, r2);
     CUBLAS_CALL(cublasGemmStridedBatched(cuda_dtype,
                                          cuhandle,
                                          trans_op_l,
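The "four types matmul" comment in the last hunk covers operands whose two batch dims are either both present (N, L) or both 1; the branch flattens them into batch = max(l1, r1) * max(l2, r2) sub-GEMMs, again using stride 0 for the (1, 1) side. A NumPy sketch of that flattening (illustration only, with made-up shapes):

import numpy as np

l = np.random.rand(2, 3, 4, 5).astype(np.float32)   # batch dims (N, L) = (2, 3)
r = np.random.rand(1, 1, 5, 6).astype(np.float32)   # batch dims (1, 1): broadcast
batch = max(l.shape[0], r.shape[0]) * max(l.shape[1], r.shape[1])  # 6 sub-GEMMs

ref = l @ r                                     # NumPy broadcasting over (N, L)
one = l.reshape(batch, 4, 5) @ r.reshape(5, 6)  # stride 0 would reuse the (1, 1) side
assert np.allclose(one.reshape(2, 3, 4, 6), ref)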
13 changes: 7 additions & 6 deletions python/tests/ops/test_matmul_op.py
@@ -56,8 +56,9 @@ def cinn_func(self, builder, x, y):

     def build_cinn_program(self, target):
         builder = NetBuilder("matmul")
-        x = builder.create_input(Float(32), self.inputs["x"].shape, "x")
-        y = builder.create_input(Float(32), self.inputs["y"].shape, "y")
+        dtype = self.nptype2cinntype(self.inputs["x"].dtype)
+        x = builder.create_input(dtype, self.inputs["x"].shape, "x")
+        y = builder.create_input(dtype, self.inputs["y"].shape, "y")
         out = self.cinn_func(builder, x, y)

         prog = builder.build()
@@ -145,8 +146,8 @@ def init_case(self):
 class TestMatmulCase8(TestMatmulOp):
     def init_case(self):
         self.inputs = {
-            "x": np.random.random([8, 16, 4]).astype("float32"),
-            "y": np.random.random([1, 4, 16]).astype("float32")
+            "x": np.random.random([8, 16, 4]).astype("float16"),
+            "y": np.random.random([1, 4, 16]).astype("float16")
         }
         self.transpose_x = False
         self.transpose_y = False
@@ -165,8 +166,8 @@ def init_case(self):
 class TestMatmulCase10(TestMatmulOp):
     def init_case(self):
         self.inputs = {
-            "x": np.random.random([8, 16, 4]).astype("float32"),
-            "y": np.random.random([4, 16]).astype("float32")
+            "x": np.random.random([8, 16, 4]).astype("float16"),
+            "y": np.random.random([4, 16]).astype("float16")
         }
         self.transpose_x = False
         self.transpose_y = False
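Note that TestMatmulCase8's new float16 inputs have exactly the [1, bs, M, K] * [1, 1, K, N] shape pattern the new cublasGemm fast path targets (my mapping of test to code path, not stated in the commit). A quick NumPy check of the broadcast shapes:

import numpy as np

x = np.random.random([8, 16, 4]).astype("float16")
y = np.random.random([1, 4, 16]).astype("float16")
out = x @ y                      # y is broadcast across the batch of 8
assert out.shape == (8, 16, 16)
assert out.dtype == np.float16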
