# FlashScan

Tiled parallel-scan Triton kernel for the Mamba-2 SSD recurrence. Single kernel, no custom CUDA.

Forward pass: `tl.associative_scan` with a (gate, value) monoid over tiles of `BLOCK_T` timesteps. Backward pass: the same scan run in reverse via `tl.associative_scan(..., reverse=True)`.
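The combine rule behind that monoid can be sketched in plain NumPy. This is an illustrative reconstruction, not the kernel's actual code: each timestep contributes a pair `(a_t, b_t)` for the linear recurrence `h_t = a_t * h_{t-1} + b_t`, and two pairs merge associatively, which is the property `tl.associative_scan` requires.

```python
# Illustrative NumPy reconstruction of a (gate, value) monoid; the names
# (combine, a, b) are hypothetical, not the kernel's identifiers.
import numpy as np

def combine(left, right):
    """Merge two segments of h_t = a_t * h_{t-1} + b_t.

    (a1, b1) followed by (a2, b2) acts as h -> a2 * (a1 * h + b1) + b2,
    i.e. the pair (a1 * a2, a2 * b1 + b2)."""
    a1, b1 = left
    a2, b2 = right
    return (a1 * a2, a2 * b1 + b2)

rng = np.random.default_rng(0)
a = rng.uniform(0.5, 1.0, size=8)   # gates (decays)
b = rng.normal(size=8)              # values (inputs)

# Sequential reference recurrence.
h, seq = 0.0, []
for t in range(8):
    h = a[t] * h + b[t]
    seq.append(h)

# Prefix scan via the monoid: folding combine left-to-right reproduces the
# same states; associativity is what lets tiles be merged in parallel.
acc, par = (1.0, 0.0), []           # (1, 0) is the identity element
for t in range(8):
    acc = combine(acc, (a[t], b[t]))
    par.append(acc[1])

assert np.allclose(seq, par)

# Associativity check on three elements.
e = [(a[t], b[t]) for t in range(3)]
grouped_left = combine(combine(e[0], e[1]), e[2])
grouped_right = combine(e[0], combine(e[1], e[2]))
assert np.allclose(grouped_left, grouped_right)
```

Because the merge is associative, each tile of `BLOCK_T` timesteps can be reduced independently and the per-tile results combined in any grouping.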

## Performance

Benchmarked on an H200 SXM against `mamba_chunk_scan_combined` (mamba-ssm), with `batch=16`, `n_heads=8`, `d_state=64`, `BLOCK_T=128`.

- Forward: 1.64x faster at `seqlen=32768`
- Forward + backward: 2.42x faster at `seqlen=32768`

Full results and methodology are in `blog.md`.

## Files

```
flashscan/
  reference.py        # Naive sequential PyTorch loop (correctness target)
  kernel.py           # Forward + backward Triton kernels (sequential and parallel)
  test.py             # Stress tests + torch.autograd.gradcheck
  benchmark.py        # Timing: naive vs sequential vs parallel vs mamba-ssm
  benchmark_v2.py     # BLOCK_T sweep + optimized comparison
  benchmark_v3.py     # Forward + backward timing + numerical stability
  profile_kernels.py  # Pure CUDA kernel time via torch.profiler
  blog.md             # Writeup with derivations and all benchmark tables
```

## Requirements

- PyTorch >= 2.8
- Triton >= 3.4.0
- NVIDIA GPU (tested on H200 SXM, `sm_90`)
- mamba-ssm (benchmarks only; not required for the kernel)

## Usage

```python
from flashscan.kernel import flashscan_parallel

# x: (batch, seqlen, n_heads, d_state) float16
# A: (batch, seqlen, n_heads) float32 (log-space decay)
# B: (batch, seqlen, n_heads, d_state) float16
# C: (batch, seqlen, n_heads, d_state) float16
y = flashscan_parallel(x, A, B, C, BLOCK_T=128)
y.sum().backward()  # full autograd support
```
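As a mental model of what `flashscan_parallel` computes, here is a naive sequential loop consistent with the shape comments above. This is a hedged sketch inferred from those shapes, not the repo's `reference.py`, which is the authoritative definition; the actual recurrence and output shape may differ.

```python
import torch

def ssd_reference(x, A, B, C):
    """Hypothetical naive loop: h_t = exp(A_t) * h_{t-1} + B_t * x_t,
    y_t = sum over d_state of C_t * h_t (an assumed contraction)."""
    b, s, h, d = x.shape
    state = torch.zeros(b, h, d, dtype=torch.float32, device=x.device)
    ys = []
    for t in range(s):
        decay = torch.exp(A[:, t]).unsqueeze(-1)       # (b, h, 1); A is log-space
        state = decay * state + B[:, t].float() * x[:, t].float()
        ys.append((C[:, t].float() * state).sum(-1))   # contract over d_state
    return torch.stack(ys, dim=1)                      # (b, s, h)
```

The scan kernel replaces this O(seqlen) sequential loop with tile-parallel work while producing the same states.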

## Tests

```shell
cd flashscan && python test.py
```
