Tiled parallel-scan Triton kernel for the Mamba-2 SSD recurrence. A single Triton kernel, no custom CUDA.
The forward pass runs tl.associative_scan with a (gate, value) monoid over tiles of BLOCK_T timesteps; the backward pass runs the same scan in reverse via tl.associative_scan(..., reverse=True).
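For intuition, the combine function of such a scan could look like the sketch below. This is illustrative only; the names and tiling are assumptions, not the kernel's actual code:

```python
import triton
import triton.language as tl

@triton.jit
def ssd_combine(a_l, h_l, a_r, h_r):
    # Monoid for the recurrence h_t = a_t * h_{t-1} + b_t:
    # composing the earlier pair (a_l, h_l) with the later pair
    # (a_r, h_r) gives (a_r * a_l, a_r * h_l + h_r). This combine
    # is associative, so a parallel scan over timesteps is valid.
    return a_r * a_l, a_r * h_l + h_r

# Inside the kernel, one tile of BLOCK_T timesteps would be scanned
# along the time axis, e.g.:
#   a_out, h_out = tl.associative_scan((a, bx), axis=0, combine_fn=ssd_combine)
# and the backward pass reuses the same combine with reverse=True.
```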
Benchmarked on an H200 SXM against mamba_chunk_scan_combined from mamba-ssm, with batch=16, n_heads=8, d_state=64, BLOCK_T=128:

- Forward: 1.64x faster at seqlen=32768
- Forward + backward: 2.42x faster at seqlen=32768

Full results and methodology are in blog.md.
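A comparison like this typically rests on CUDA-event timing with warmup excluded. A minimal sketch (time_fn is a hypothetical helper, not the repo's exact harness):

```python
import torch

def time_fn(fn, *args, warmup=10, iters=50):
    # Time a CUDA op with events, excluding warmup and JIT compilation.
    for _ in range(warmup):
        fn(*args)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(*args)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # ms per call
```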
Repository layout:

```
flashscan/
    reference.py        # Naive sequential PyTorch loop (correctness target)
    kernel.py           # Forward + backward Triton kernels (sequential and parallel)
    test.py             # Stress tests + torch.autograd.gradcheck
    benchmark.py        # Timing: naive vs sequential vs parallel vs mamba-ssm
    benchmark_v2.py     # BLOCK_T sweep + optimized comparison
    benchmark_v3.py     # Forward+backward timing + numerical stability
    profile_kernels.py  # Pure CUDA kernel time via torch.profiler
    blog.md             # Writeup with derivations and all benchmark tables
```
Requirements:

- PyTorch >= 2.8
- Triton >= 3.4.0
- NVIDIA GPU (tested on H200 SXM, sm_90)
- mamba-ssm (for benchmarks only, not required for the kernel)
Usage:

```python
from flashscan.kernel import flashscan_parallel

# x: (batch, seqlen, n_heads, d_state) float16
# A: (batch, seqlen, n_heads)          float32 (log-space decay)
# B: (batch, seqlen, n_heads, d_state) float16
# C: (batch, seqlen, n_heads, d_state) float16
y = flashscan_parallel(x, A, B, C, BLOCK_T=128)
y.sum().backward()  # full autograd support
```

Run the tests with:

```
cd flashscan && python test.py
```
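The tests check outputs against reference.py, the naive sequential loop. As a rough sketch of what a correctness target for the documented shapes might look like (the elementwise recurrence below is an assumption; reference.py is authoritative):

```python
import torch

def reference_ssd(x, A, B, C):
    # Hypothetical naive loop for the documented shapes; the real
    # correctness target lives in flashscan/reference.py.
    b, s, h, d = x.shape
    a = torch.exp(A.float())                       # (b, s, h) per-step decay
    state = x.new_zeros(b, h, d, dtype=torch.float32)
    ys = []
    for t in range(s):
        # h_t = a_t * h_{t-1} + B_t * x_t  (elementwise over d_state)
        state = a[:, t, :, None] * state + (B[:, t] * x[:, t]).float()
        ys.append((C[:, t].float() * state).to(x.dtype))
    return torch.stack(ys, dim=1)                  # (batch, seqlen, n_heads, d_state)
```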