cuda: use strided batched GEMM for linear batch layouts #3704
Open
LwhJesse wants to merge 1 commit into
Summary
This PR adds a CUDA strided-batched GEMM fast path for batched matrix multiplication when the batch layout can be represented by a single linear batch stride.
The existing pointer-array batched GEMM path is still used as the fallback. The new path is only selected when the lhs, rhs, and output batch layouts can all be flattened safely into a strided-batched representation.
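One way to picture the "can be flattened safely" condition is the host-side check below. It is a minimal sketch, not the code from this PR: `BatchLayout`, `linearBatchStride`, and `canUseStridedBatched` are hypothetical names, and it assumes two batch dimensions per operand, with a stride of 0 standing in for a broadcast dimension.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical description of one operand's batch layout (illustrative
// names, not taken from the PR). dim2/dim3 are the output batch extents;
// stride2/stride3 are this operand's element strides along those
// dimensions, with 0 modelling a broadcast dimension.
struct BatchLayout {
    std::int64_t dim2, dim3;
    std::int64_t stride2, stride3;
};

// Returns the single batch stride S if the offset i*stride2 + j*stride3
// equals (i + j*dim2) * S for every batch index (i, j); otherwise nullopt.
std::optional<std::int64_t> linearBatchStride(const BatchLayout& l) {
    assert(l.dim2 >= 1 && l.dim3 >= 1);
    if (l.dim2 == 1 && l.dim3 == 1) return std::int64_t{0};  // one matrix
    if (l.dim3 == 1) return l.stride2;  // only dim2 varies
    if (l.dim2 == 1) return l.stride3;  // only dim3 varies
    // Both vary: dim3 must continue exactly where dim2 leaves off.
    if (l.stride3 == l.stride2 * l.dim2) return l.stride2;
    return std::nullopt;
}

// The fast path is only eligible when lhs, rhs, and output all flatten.
bool canUseStridedBatched(const BatchLayout& lhs, const BatchLayout& rhs,
                          const BatchLayout& out) {
    return linearBatchStride(lhs).has_value() &&
           linearBatchStride(rhs).has_value() &&
           linearBatchStride(out).has_value();
}
```

Under these assumptions, a layout that is broadcast along one batch dimension but not the other does not flatten, so it would stay on the pointer-array fallback.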
In particular, this PR:
uses cublasGemmStridedBatchedEx when the layout is compatible;
falls back to the existing gemmBatchedDispatch path otherwise.
Motivation
For compatible batched layouts, the pointer-array batched GEMM path has extra setup overhead: host-side pointer arrays are prepared and copied to the device before the cuBLAS batched GEMM call.
When the same operation can be represented as a strided-batched GEMM, cuBLAS can express the batch layout directly through base pointers and batch strides. This avoids the extra pointer-array setup and host-to-device pointer copies.
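The difference between the two representations can be sketched in plain host code. This is an illustration only: `makeBatchPointers`, `StridedBatch`, and `batchAt` are hypothetical helpers, not ArrayFire or cuBLAS functions. The pointer-array form materializes one address per batch entry, while the strided form describes the same addresses with a base pointer and one stride.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Pointer-array form: one host-side pointer per batch entry. In a
// pointer-array batched GEMM, a table like this also has to be copied
// to the device before the call.
std::vector<const float*> makeBatchPointers(const float* base,
                                            std::size_t batchStride,
                                            std::size_t batchCount) {
    assert(base != nullptr);
    std::vector<const float*> ptrs(batchCount);
    for (std::size_t b = 0; b < batchCount; ++b) {
        ptrs[b] = base + b * batchStride;
    }
    return ptrs;
}

// Strided form: the same addresses described by two scalars, so no
// per-batch table needs to be built or transferred.
struct StridedBatch {
    const float* base;
    std::size_t stride;  // elements between consecutive matrices
};

const float* batchAt(const StridedBatch& s, std::size_t b) {
    return s.base + b * s.stride;
}
```

When the layout is linear, both forms address the same matrices; only the setup work differs.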
Correctness
Tested locally with the CUDA backend on an RTX 4060 Laptop GPU with CUDA 13.2.
The following CUDA BLAS tests passed:
test/blas_cuda: 127/127 passed
Focused test filter:
AF_MULTI_GPU_TESTS=0 ./test/blas_cuda \
  --gtest_filter='LHSBroadcast/MatrixMultiplyBatch.*:RHSBroadcast/MatrixMultiplyBatch.*:SameBatch/MatrixMultiplyBatch.*:Batched/Gemm.*'
No numerical regression was observed in the tested BLAS/GEMM CUDA paths.
Note: the local system required separate CUDA 13.2 / fmt12 / CCCL-Thrust compatibility patches to build ArrayFire. Those compatibility patches are not part of this PR. Both the baseline and optimized builds used the same local compatibility patch stack for the A/B comparison.
Benchmark
Local A/B benchmark was run using two independent worktrees:
492718b5a256d4a9d5198fdce89d8fd21772bfda60d325086
Both builds used the same local CUDA 13.2 compatibility patch stack and the same build settings.
Hardware / environment:
Representative batched GEMM cases:
All representative cases showed a speedup in this local benchmark, ranging from about 3.5% to 10.2%.
Nsight Systems check
I also profiled the RHSBroadcast case with Nsight Systems.
The main GEMM kernel remained the same in both builds:
ampere_sgemm_128x128_nn
The GPU GEMM kernel time was nearly unchanged:
The difference was in the surrounding setup overhead. The baseline showed a batch of pointer-array setup copies:
cudaMemcpyAsync: 660 calls
The optimized path did not show the same batch of host-to-device pointer-array copies.
This matches the intent of the change: the speedup comes from avoiding pointer-array batched GEMM setup when the layout can be represented as strided batched GEMM, not from changing the underlying GEMM kernel itself.
Scope
This PR only changes the CUDA batched GEMM dispatch logic in:
src/backend/cuda/blas.cu
It does not change: