04_fused_softmax - row-wise softmax on Ampere

A four-step walk from a one-thread-per-row baseline to a single-pass online softmax kernel. Each step moves one Nsight Compute counter, and the online variant doubles as the inner loop that Layer 4's Flash Attention kernel reuses. cuDNN's ACCURATE and FAST algorithms are registered as two baselines.

Results

|                 | naive | shared | warp_shuffle | online | cudnn_acc | cudnn_fast |
|-----------------|-------|--------|--------------|--------|-----------|------------|
| ms @ 4096×4096  | 5.878 | 1.650  | 1.214        | 1.421  | 1.092     | 1.092      |
| GB/s (2·N·4 B)  | 22.8  | 81.3   | 110.5        | 94.4   | 122.9     | 122.9      |
| % cuDNN         | 19    | 66     | 90           | 77     | 100       | 100        |

| step | what changed | counter that moved | step gain |
|------|--------------|--------------------|-----------|
| 00_naive.cu | 1 thread/row, 3 passes over GMEM | DRAM 21.56 % | baseline |
| 01_shared.cu | 1 block/row, SMEM tree-fold for max and sum | DRAM 22 % → 36 % | 3.56× |
| 02_warp_shuffle.cu | __shfl_xor_sync finalise + one cross-warp SMEM round-trip | DRAM 36 % → 64 % | 1.36× |
| 03_online.cu | fused (max, sum) online pass + branchless warp merge | (negative at this size) | 0.85× |

online runs slower than warp_shuffle here because softmax at this shape is compute-bound on __expf, and the online merge does 3 __expf/element versus warp_shuffle's 1. The pattern wins back its cost in Layer 4's fused attention, where the inputs are computed inside the kernel and never hit GMEM: the memory saved is real there, imaginary here. This step is framed as introducing the primitive, not as a perf win.

Experimental Setup

Queried via cudaGetDeviceProperties / cudaDeviceGetAttribute:
  • GPU: NVIDIA GeForce RTX 3050 Laptop GPU (GA107), sm_86, 16 SMs
  • Per-SM: 65,536 registers, 1,536 threads, 100 KB shared memory, 128 KB unified L1/TEX
  • Off-chip: 3.68 GB VRAM, 128-bit bus, 192 GB/s peak DRAM
  • Toolkit / driver: CUDA 13.0.88, driver 580.82.09, compiled -O3 --gpu-architecture=sm_86
  • cuDNN: 9.x (vendored at ../cuda-kernel-portfolio/cudnn, bundled with this repo tree)
  • Shape: rows × cols = 4096 × 4096 (16.8 M floats, 67 MB per matrix). Effective bandwidth counts input read + output write = 2·N·4 B = 134 MB of traffic. At DRAM peak (192 GB/s) that puts the bandwidth floor at 0.70 ms; cudnn_acc's 1.092 ms works out to 122.9 GB/s, 64 % of peak, reached with a tile-scheduled kernel that overlaps multiple rows per SM.
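
The GB/s column in the Results table follows directly from this traffic model. A standalone sanity check of the arithmetic (illustrative host-only snippet, not part of the repo):

```cuda
// Reproduces the effective-bandwidth column from the measured times.
// Traffic model: one input read + one output write = 2 * N * 4 bytes.
#include <cstdio>

int main() {
    const double n     = 4096.0 * 4096.0;        // 16.8 M floats
    const double bytes = 2.0 * n * 4.0;          // 134 MB of GMEM traffic
    const char*  name[] = {"naive", "shared", "warp_shuffle", "online", "cudnn"};
    const double ms[]   = {5.878, 1.650, 1.214, 1.421, 1.092};
    for (int i = 0; i < 5; ++i)                  // GB/s = bytes / (ms * 1e6)
        std::printf("%-13s %6.1f GB/s\n", name[i], bytes / (ms[i] * 1e6));
    std::printf("floor @ 192 GB/s: %.2f ms\n", bytes / 192e9 * 1e3);
    return 0;
}
```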

Summary

Row 0 → Row 1 — parallelise the reductions. naive runs one thread per row, 3 sequential passes over cols elements each. L1/TEX sits at 88.60 % because each thread's read pattern is a 16 KB stripe (cols elements × 4 B) that spills across cache lines; DRAM only sees 21.56 % of peak because most traffic hits and stays in L1. Compute (SM) at 4.83 % shows the block of 256 threads is mostly idle as 256 separate serial loops. shared maps one block per row and reduces in parallel: DRAM climbs 22 → 36 %, Compute (SM) jumps 5 → 58 % (the SMEM tree-fold does real work), and Duration falls 5.88 → 1.65 ms. The L1/TEX traffic drops by the same factor as the thread-level parallelism.
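
A sketch of the 01_shared.cu shape (illustrative, not the repo's exact kernel; softmax_shared and its launch are our names, and a power-of-2 blockDim.x is assumed, which the launcher's rounding guarantees):

```cuda
// One block per row, three passes, SMEM tree-fold for max and sum.
// Launch sketch: softmax_shared<<<rows, 512, 512 * sizeof(float)>>>(x, y, cols);
__global__ void softmax_shared(const float* __restrict__ x,
                               float* __restrict__ y, int cols) {
    extern __shared__ float smem[];
    const float* row = x + (size_t)blockIdx.x * cols;
    float*       out = y + (size_t)blockIdx.x * cols;

    // Pass 1: strided per-thread max, then a stride-halving fold in SMEM.
    float m = -INFINITY;
    for (int c = threadIdx.x; c < cols; c += blockDim.x)
        m = fmaxf(m, row[c]);
    smem[threadIdx.x] = m;
    __syncthreads();
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (threadIdx.x < s)
            smem[threadIdx.x] = fmaxf(smem[threadIdx.x], smem[threadIdx.x + s]);
        __syncthreads();
    }
    const float row_max = smem[0];
    __syncthreads();                      // everyone reads smem[0] before reuse

    // Pass 2: sum of exp(x - max) with the same fold.
    float sum = 0.f;
    for (int c = threadIdx.x; c < cols; c += blockDim.x)
        sum += __expf(row[c] - row_max);
    smem[threadIdx.x] = sum;
    __syncthreads();
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (threadIdx.x < s)
            smem[threadIdx.x] += smem[threadIdx.x + s];
        __syncthreads();
    }
    const float inv_sum = 1.f / smem[0];

    // Pass 3: normalise and write.
    for (int c = threadIdx.x; c < cols; c += blockDim.x)
        out[c] = __expf(row[c] - row_max) * inv_sum;
}
```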

Row 1 → Row 2 — drop the SMEM round-trips. warp_shuffle replaces the per-stride SMEM tree-fold with a warp-shuffle __shfl_xor_sync that runs in registers. A block-size tree-fold did 9 SMEM rounds (log2(512)) plus 9 __syncthreads; warp-shuffle does 5 register shuffles with no barrier, plus one SMEM write-read for the cross-warp merge. DRAM rises 36 → 64 %, L1/TEX falls 58 → 47 %, Compute (SM) falls 58 → 45 % (less SMEM traffic means less pipe use). Duration 1.65 → 1.21 ms.
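
The reduction primitive this step introduces, sketched (illustrative; block_reduce_max is our name, and the sum variant is the same loop with + and a 0 identity):

```cuda
// 5 xor-shuffles fold a 32-lane warp in registers; one SMEM write-read
// merges the per-warp partials. Shown for max; sum is identical with +.
__device__ float block_reduce_max(float v) {
    __shared__ float warp_max[32];                 // one slot per warp
    const int lane = threadIdx.x & 31;
    const int warp = threadIdx.x >> 5;

    // Intra-warp butterfly: log2(32) = 5 rounds, no barrier, no SMEM.
    for (int off = 16; off > 0; off >>= 1)
        v = fmaxf(v, __shfl_xor_sync(0xffffffffu, v, off));

    // Cross-warp merge: the single SMEM round-trip the text mentions.
    if (lane == 0) warp_max[warp] = v;
    __syncthreads();
    if (warp == 0) {
        const int nwarps = (blockDim.x + 31) >> 5;
        v = (lane < nwarps) ? warp_max[lane] : -INFINITY;
        for (int off = 16; off > 0; off >>= 1)
            v = fmaxf(v, __shfl_xor_sync(0xffffffffu, v, off));
        if (lane == 0) warp_max[0] = v;
    }
    __syncthreads();
    return warp_max[0];                            // broadcast to every thread
}
```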

Row 2 → Row 3 — fused online pass. online computes (max, sum_exp) pairs incrementally, rescaling local_sum by exp(old_max - new_max) each time the running maximum grows. One GMEM read pass computes both statistics; the second pass computes exp(x - max) * inv_sum and writes the output. warp_shuffle, by contrast, does 3 GMEM read passes (max, exp+sum, normalise). At 2048 × 2048 in Layer 4 attention the saved reads dominate; at 4096 × 4096 here, the pass online eliminates would largely have hit L2 anyway (L2 Throughput 22.92 % vs warp_shuffle's 40.85 %), so the saved reads are cheap, while the extra __expf calls per element cost full price in a compute-bound kernel. Duration rises 1.21 → 1.42 ms.
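
The primitive under discussion, sketched (illustrative, not the repo's exact 03_online.cu; the fminf clamp is one way to get the branchless NaN-safe merge that the Verification section credits to the audit):

```cuda
// Fused (max, sum) state: m is the running max, s the running sum of
// exp(x - m). Updated per element, merged lane-to-lane with shuffles.
struct OnlineStat { float m, s; };                 // init: { -INFINITY, 0.f }

// Per-element update. Skipping val == -INFINITY keeps an all-masked row
// from evaluating __expf(-INF - -INF) = __expf(NaN).
__device__ OnlineStat online_update(OnlineStat a, float val) {
    if (val == -INFINITY) return a;
    float m = fmaxf(a.m, val);
    // 2 __expf here (rescale old sum + new term); the output pass spends a
    // third -> the 3 __expf/element vs warp_shuffle's 1 discussed above.
    return {m, a.s * __expf(a.m - m) + __expf(val - m)};
}

// Branchless merge of two partials. The exponents are mathematically <= 0;
// fminf(.., 0.f) also maps the -INF - -INF = NaN corner to 0 (CUDA's fminf
// returns the non-NaN operand), so all-masked rows merge to (-INF, 0).
__device__ OnlineStat online_merge(OnlineStat a, OnlineStat b) {
    float m = fmaxf(a.m, b.m);
    return {m, a.s * __expf(fminf(a.m - m, 0.f))
             + b.s * __expf(fminf(b.m - m, 0.f))};
}

// Warp-level fold: shuffle each field, merge, 5 rounds as before.
__device__ OnlineStat warp_merge(OnlineStat st) {
    for (int off = 16; off > 0; off >>= 1) {
        OnlineStat o;
        o.m = __shfl_xor_sync(0xffffffffu, st.m, off);
        o.s = __shfl_xor_sync(0xffffffffu, st.s, off);
        st  = online_merge(st, o);
    }
    return st;
}
```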

Rows 2/3 → cuDNN. cudnn_acc runs softmax_fw_kernel_resident at DRAM 95.10 %, 0.735 ms. Its tile scheduling issues one block per SM hosting many rows (Waves per SM = 64), amortising launch and filling the pipeline deeper than our one-row-per-block design. The kernel fits in 45 KB of SMEM but the tile-per-block strategy keeps 4 blocks/SM resident versus our 1.
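
For reference, registering the two cuDNN baselines presumably comes down to the stock cudnnSoftmaxForward call; a minimal sketch (wrapper name and error handling are ours):

```cuda
// Row-wise softmax via cuDNN: rows map to N, cols to C, H = W = 1, and
// CUDNN_SOFTMAX_MODE_INSTANCE reduces over C*H*W per row.
#include <cudnn.h>

void cudnn_softmax(cudnnHandle_t h, const float* d_x, float* d_y,
                   int rows, int cols, bool accurate) {
    cudnnTensorDescriptor_t desc;
    cudnnCreateTensorDescriptor(&desc);
    cudnnSetTensor4dDescriptor(desc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT,
                               rows, cols, 1, 1);
    const float alpha = 1.f, beta = 0.f;
    cudnnSoftmaxForward(h,
                        accurate ? CUDNN_SOFTMAX_ACCURATE : CUDNN_SOFTMAX_FAST,
                        CUDNN_SOFTMAX_MODE_INSTANCE,
                        &alpha, desc, d_x, &beta, desc, d_y);
    cudnnDestroyTensorDescriptor(desc);
}
```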

Verification

  • Cross-checked against a Kahan-summed, FP64-accumulated CPU reference (audit §L3.1.1/§L3.2.1). The previous cpu_softmax reference used a naive forward sum, whose error of ~cols · eps ≈ 5e-4 would sit above the tightened rtol = 1e-4.
  • verify_close uses atol = 1e-6, rtol = 1e-4 (see the sketch after this list). For cols = 4096, typical output values are 1/cols ≈ 2.4e-4; the previous tol = 1e-4 absolute would have let a systematic normaliser bug of that size slip through.
  • Output poisoned with NaN before each launch (§0.1).
  • 01_shared.cu's block size is rounded to the next power of 2 in the launcher so the stride-halving tree-fold is correct for any cols (audit §L3.1.1 — the reference kernel only worked for power-of-2 block sizes but was launched with ceil(cols/32) * 32, which was buggy for cols ∈ {96, 160, ...}).
  • 03_online.cu uses the branchless (max, sum) merge and skips elements with val == -INFINITY (audit §L3.1.2 — the reference if (val > local_max) branch computed expf(NaN) on rows of all -INF, the masked-attention corner case; Layer 4 would have inherited the bug).
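
A sketch of the tolerance check those bullets describe (illustrative; the repo's verify_close may differ in detail):

```cuda
// Mixed absolute/relative closeness against the FP64 CPU reference.
// Writing the test as !(diff <= tol) makes NaN poison (§0.1) fail loudly,
// since any comparison with NaN is false.
#include <cmath>
#include <cstddef>

bool verify_close(const float* got, const double* ref, std::size_t n,
                  double atol = 1e-6, double rtol = 1e-4) {
    for (std::size_t i = 0; i < n; ++i) {
        const double diff = std::fabs((double)got[i] - ref[i]);
        if (!(diff <= atol + rtol * std::fabs(ref[i])))
            return false;
    }
    return true;
}
```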

Reproducing

Build:

rm -rf build && mkdir build && cd build
cmake .. && cmake --build . --parallel
cd ..

Run the Layer-3 sweep:

export LD_LIBRARY_PATH=PATH_TO_CUDNN/cudnn/lib:$LD_LIBRARY_PATH
./build/bin/softmax_bench --rows 4096 --cols 4096 --iters 20 --runs 5 --warmup 3

Capture profiles:

./scripts/profile_layer4.sh

Scope

  • FP32 only; power-of-2 cols not required. 01_shared.cu rounds block size to the next pow2 in the launcher; 02_warp_shuffle.cu and 03_online.cu require cols to be a multiple of 32.
  • Single shape in the benchmark: rows = cols = 4096. Transformer attention softmax has rows = B · H · seq_len and cols = seq_len; at seq_len = 2048 with 16 heads, effective rows is O(1e5) even at small batch — our variants scale linearly in rows, but the Pareto optimum may shift to online if rows dominate and the extra memory traffic of the 3-pass variants hurts.
  • No masking. 03_online.cu defends against -INF inputs so masked attention works; this is intended for Layer 4 and is not tested directly in this layer's bench (the inputs are uniform).

About

Row-wise CUDA softmax kernels: shared-memory reduction, warp-shuffle reduction, and online softmax benchmarked against cuDNN.
