Add sparse MLA forward op in experimental #91
Conversation
Implements a sparse variant of the dense MLA operator where each query position attends only to top-k KV entries specified by an index tensor. Features TILE_H multi-head-per-block tiling that amortizes K/V/KPE gather cost across heads (~2.9x speedup), 3-path config dispatch (explicit kernel_configs / DISABLE_AUTOTUNE / autotune), and config validation with GQA alignment constraints.
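The 3-path config dispatch and GQA-alignment validation described above can be sketched as follows. All names here (`pick_config`, `validate_config`, the `kernel_configs` dict shape, the `DISABLE_AUTOTUNE` env var spelling) are illustrative assumptions, not the actual experimental API:

```python
import os

# Hypothetical sketch of a 3-path config dispatch:
# explicit kernel_configs -> DISABLE_AUTOTUNE fallback -> autotune.
DEFAULT_CONFIG = {"TILE_H": 1}

def validate_config(cfg, H, H_kv):
    # GQA alignment constraint: TILE_H must evenly divide the
    # query-head group size H // H_kv so a block never straddles groups.
    group = H // H_kv
    if group % cfg["TILE_H"] != 0:
        raise ValueError(f"TILE_H={cfg['TILE_H']} must divide group size {group}")
    return cfg

def pick_config(H, H_kv, kernel_configs=None, autotune=lambda: {"TILE_H": 4}):
    # Path 1: caller supplied an explicit config -- validate and use it.
    if kernel_configs is not None:
        return validate_config(kernel_configs, H, H_kv)
    # Path 2: autotuning disabled -- fall back to a safe default.
    if os.environ.get("DISABLE_AUTOTUNE"):
        return DEFAULT_CONFIG
    # Path 3: autotune over candidate configs (stubbed out here).
    return validate_config(autotune(), H, H_kv)
```

Rejecting misaligned configs at dispatch time (rather than in the kernel) is what the invalid-`kernel_configs` tests below exercise.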
Covers basic shapes, GQA, topk==S_kv, irregular shapes, 8-row forced-config correctness matrix (TILE_H values, non-pow2 group sizes), and invalid kernel_configs rejection tests. 17 test cases total.
CuTile vs PyTorch benchmark sweeping topk with TFLOPS reporting. Includes correctness validation at small topk values.
/ok to test ea2bcdf

/ok to test 81b89f8
)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("backend", _backends)
def test_basic(self, B, H, S, S_kv, D, D_PE, H_kv, topk, dtype, backend, arch):
Thanks for the contribution! The kernel implementation and tests look solid overall. One issue that needs fixing before merge:
The CI workflow runs ops tests with pytest -s tests/ops tests/suites -v -k test_op
Your current test methods (test_basic, test_gqa, test_topk_equals_skv, test_irregular_shapes, test_tile_h_configs, test_invalid_kernel_configs) will all be silently skipped by CI because none of them matches the test_op pattern.
Please rename your test methods to use the test_op_ prefix, e.g.: test_basic → test_op_basic.
Sorry for the inconvenience.
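As context for the rename: pytest's -k with a bare word is a substring match against the test id, so the skip/collect behavior can be checked mechanically. A minimal illustration of the matching rule (this mimics -k's simple substring case, it is not pytest itself):

```python
# Illustration of pytest "-k test_op": only names containing
# the pattern are collected; everything else is silently skipped.

def matches_k(name, pattern="test_op"):
    # "-k word" with a bare word selects tests whose id contains it
    return pattern in name

old_names = ["test_basic", "test_gqa", "test_invalid_kernel_configs"]
new_names = ["test_op_" + n[len("test_"):] for n in old_names]

assert not any(matches_k(n) for n in old_names)  # all skipped by CI
assert all(matches_k(n) for n in new_names)      # all collected
```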
Thanks for the careful review! I've renamed the test cases as required.
/ok to test 33bc2b5
Description
Add the kernel implementation, test cases, and benchmarks for the sparse MLA (Multi-head Latent Attention) forward op. This is the sparse variant of the existing dense MLA op, where each query position attends only to the top-k KV entries specified by an index tensor rather than to all S_kv positions. This attention pattern is used in the DeepSeek V3.2 model.
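For intuition, the gather-then-attend pattern can be written densely in PyTorch. This is a naive reference sketch, not the CuTile kernel: it assumes H_kv == 1 with a single shared latent used for both K and V, and it omits the separate D_PE (RoPE) component; the tensor layouts are assumptions, not the op's actual signature:

```python
import torch

def sparse_mla_ref(q, kv, indices, scale):
    """Naive reference: each query attends only to its top-k KV entries.

    q:       [B, H, S, D]    queries
    kv:      [B, S_kv, D]    shared latent KV (H_kv == 1 for simplicity)
    indices: [B, S, topk]    per-query top-k KV positions (int64)
    Layouts are illustrative assumptions, not the op's real interface.
    """
    B, H, S, D = q.shape
    topk = indices.shape[-1]
    # Gather the selected KV rows for every query position: [B, S, topk, D]
    gathered = torch.gather(
        kv.unsqueeze(1).expand(B, S, kv.shape[1], D),
        2,
        indices.unsqueeze(-1).expand(B, S, topk, D),
    )
    # Scores over only the top-k entries: [B, H, S, topk]
    scores = torch.einsum("bhsd,bstd->bhst", q, gathered) * scale
    probs = scores.softmax(dim=-1)
    # Weighted sum of the same gathered latents (MLA shares K/V)
    return torch.einsum("bhst,bstd->bhsd", probs, gathered)
```

With topk == S_kv and indices covering every position, this reduces to dense attention, which is the invariant the topk==S_kv test checks.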
Testing
17 test cases covering:
- basic shapes, GQA, topk == S_kv, irregular shapes
- 8-row forced-config correctness matrix over kernel_configs (TILE_H ∈ {1,2,4,16}, non-pow2 group sizes)
- invalid kernel_configs rejection
Performance
Benchmarked on RTX 5090 with B=1, S=256, S_kv=4096, H=16, H_kv=1, D=128, D_PE=64:
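For sparse attention, the TFLOPS figure is usually derived from the two matmul-like stages per selected KV entry: Q·Kᵀ over the D + D_PE dimensions and P·V over D, at 2 FLOPs per multiply-accumulate. A sketch of that arithmetic (the benchmark's exact FLOP accounting is an assumption):

```python
def sparse_mla_tflops(ms, B=1, S=256, H=16, D=128, D_PE=64, topk=2048):
    """Estimate achieved TFLOPS for one sparse MLA forward call.

    Assumes two GEMM-like stages per selected entry (2 FLOPs per MAC):
    scores Q @ K^T over D + D_PE dims, then output P @ V over D dims.
    This accounting is illustrative, not the benchmark's exact formula.
    """
    flops = 2 * B * H * S * topk * (D + D_PE)  # scores: Q @ K^T
    flops += 2 * B * H * S * topk * D          # output: P @ V
    return flops / (ms * 1e-3) / 1e12          # ms -> s, FLOPS -> TFLOPS
```

Note the FLOP count scales with topk rather than S_kv, which is why the sweep reports rising TFLOPS as topk grows while dense MLA's cost stays fixed.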
CI Configuration
Checklist
Code formatted (./format.sh)