Skip to content

cuda: use strided batched GEMM for linear batch layouts#3704

Open
LwhJesse wants to merge 1 commit into
arrayfire:masterfrom
LwhJesse:pr/cuda-strided-batched-gemm-60d325086
Open

cuda: use strided batched GEMM for linear batch layouts#3704
LwhJesse wants to merge 1 commit into
arrayfire:masterfrom
LwhJesse:pr/cuda-strided-batched-gemm-60d325086

Conversation

@LwhJesse
Copy link
Copy Markdown

Summary

This PR adds a CUDA strided-batched GEMM fast path for batched matrix multiplication when the batch layout can be represented by a single linear batch stride.

The existing pointer-array batched GEMM path is still used as the fallback. The new path is only selected when the lhs, rhs, and output batch layouts can all be flattened safely into a strided-batched representation.

In particular, this PR:

  • computes effective batch strides for inputs, including broadcasted batch dimensions;
  • checks whether the 2D batch space can be represented by one linear batch stride;
  • uses cublasGemmStridedBatchedEx when the layout is compatible;
  • falls back to the existing pointer-array gemmBatchedDispatch path otherwise.

Motivation

For compatible batched layouts, the pointer-array batched GEMM path has extra setup overhead: host-side pointer arrays are prepared and copied to the device before the cuBLAS batched GEMM call.

When the same operation can be represented as a strided-batched GEMM, cuBLAS can express the batch layout directly through base pointers and batch strides. This avoids the extra pointer-array setup and host-to-device pointer copies.

Correctness

Tested locally with the CUDA backend on an RTX 4060 Laptop GPU with CUDA 13.2.

The following CUDA BLAS tests passed:

  • test/blas_cuda: 127/127 passed
  • focused batched/broadcast tests: 47/47 passed

Focused test filter:

AF_MULTI_GPU_TESTS=0 ./test/blas_cuda \
  --gtest_filter='LHSBroadcast/MatrixMultiplyBatch.*:RHSBroadcast/MatrixMultiplyBatch.*:SameBatch/MatrixMultiplyBatch.*:Batched/Gemm.*'

No numerical regression was observed in the tested BLAS/GEMM CUDA paths.

Note: the local system required separate CUDA 13.2 / fmt12 / CCCL-Thrust compatibility patches to build ArrayFire. Those compatibility patches are not part of this PR. Both the baseline and optimized builds used the same local compatibility patch stack for the A/B comparison.

Benchmark

Local A/B benchmark was run using two independent worktrees:

  • baseline: parent of this commit, 492718b5a256d4a9d5198fdce89d8fd21772bfda
  • optimized: this PR commit, 60d325086

Both builds used the same local CUDA 13.2 compatibility patch stack and the same build settings.

Hardware / environment:

  • GPU: NVIDIA GeForce RTX 4060 Laptop GPU
  • Driver: 595.71.05
  • CUDA: 13.2
  • Build type: Release
  • Warmup: 20 iterations
  • Timed iterations: 200

Representative batched GEMM cases:

Case Baseline Optimized Speedup
SameBatch 0.147200 ms 0.140888 ms 1.0448x
LHSBroadcast 0.147233 ms 0.140849 ms 1.0453x
RHSBroadcast 0.147073 ms 0.140907 ms 1.0438x
IrregularBatch 0.193487 ms 0.186971 ms 1.0349x
IrregularBroadcast 0.076449 ms 0.069351 ms 1.1023x

All representative cases were positive in this local benchmark, with speedups from about 3.5% to 10.2%.

Nsight Systems check

I also profiled the RHSBroadcast case with Nsight Systems.

The main GEMM kernel remained the same in both builds:

  • ampere_sgemm_128x128_nn

The GPU GEMM kernel time was nearly unchanged:

  • baseline GEMM kernel total: about 28.99 ms
  • optimized GEMM kernel total: about 28.95 ms

The difference was in the surrounding setup overhead. The baseline showed a batch of pointer-array setup copies:

  • cudaMemcpyAsync: 660 calls
  • Host-to-Device memcpy total: about 355.8 us

The optimized path did not show the same batch of Host-to-Device pointer-array copies.

This matches the intent of the change: the speedup comes from avoiding pointer-array batched GEMM setup when the layout can be represented as strided batched GEMM, not from changing the underlying GEMM kernel itself.

Scope

This PR only changes the CUDA batched GEMM dispatch logic in:

  • src/backend/cuda/blas.cu

It does not change:

  • CPU backend behavior
  • OpenCL backend behavior
  • oneAPI backend behavior
  • cuBLAS GEMM kernel implementation
  • CUDA 13 / fmt12 / CCCL-Thrust compatibility code
  • general ArrayFire build system behavior

@LwhJesse LwhJesse marked this pull request as ready for review May 12, 2026 19:35
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant