
Add sparse MLA forward op in experimental#91

Merged
hannahli-nv merged 5 commits into NVIDIA:main from Weili-0234:feature/sparse-mla
Apr 2, 2026

Conversation

@Weili-0234
Contributor

Description

Add the kernel implementation, test cases, and benchmarks for the sparse MLA (Multi-head Latent Attention) forward op. This is the sparse variant of the existing dense MLA op: each query position attends only to the top-k KV entries specified by an index tensor, rather than to all S_kv positions. This attention pattern is used in the DeepSeek V3.2 model.
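The gather-then-attend semantics described above can be sketched as a naive PyTorch reference. This is an illustrative assumption based on the description, not the PR's actual kernel or API: the function name `sparse_mla_ref`, the argument layout, and the shape conventions (`[B, H, S, D]` etc.) are all hypothetical, and it reuses the latent `kv` tensor as both key and value in the usual MLA fashion.

```python
import torch

def sparse_mla_ref(q, q_pe, kv, k_pe, indices, scale):
    """Naive sparse MLA forward (hypothetical reference, not the PR's API).
    q:  [B, H, S, D]        q_pe: [B, H, S, D_PE]
    kv: [B, H_kv, S_kv, D]  k_pe: [B, H_kv, S_kv, D_PE]
    indices: [B, H_kv, S, topk] -- top-k KV positions per query position.
    In MLA the latent kv tensor serves as both key and value."""
    B, H, S, D = q.shape
    g = H // indices.shape[1]  # GQA group size: query heads per KV head
    out = torch.empty_like(q)
    for b in range(B):
        for h in range(H):
            hk = h // g                     # KV head for this query head
            idx = indices[b, hk]            # [S, topk]
            k_sel = kv[b, hk][idx]          # gather: [S, topk, D]
            kpe_sel = k_pe[b, hk][idx]      # gather: [S, topk, D_PE]
            scores = (torch.einsum("sd,std->st", q[b, h], k_sel)
                      + torch.einsum("sp,stp->st", q_pe[b, h], kpe_sel)) * scale
            attn = scores.softmax(dim=-1)   # softmax over the topk entries only
            out[b, h] = torch.einsum("st,std->sd", attn, k_sel)
    return out
```

When `topk == S_kv` and `indices` enumerates every KV position, this reduces to dense MLA attention, which matches one of the test cases listed below.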

Testing

17 test cases covering:

  • Basic shapes, GQA (group_size=4), topk==S_kv, irregular shapes
  • 8-row forced-config correctness matrix via kernel_configs (TILE_H ∈ {1,2,4,16}, non-pow2 group sizes)
  • 6 invalid-config rejection tests (partial configs, non-pow2 values, constraint violations)
DISABLE_AUTOTUNE=1 pytest tests/ops/experimental/test_sparse_mla.py -v
============================== 17 passed in 1.80s ==============================

Performance

Benchmarked on RTX 5090 with B=1, S=256, S_kv=4096, H=16, H_kv=1, D=128, D_PE=64:

sparse-mla-topk-scaling-bfloat16-TFLOPS:
     topk  CuTile (TFLOPS)  PyTorch (TFLOPS)
0    64.0         2.117587          0.122518
1   128.0         2.649377          0.125455
2   256.0         2.893056          0.127802
3   512.0         3.040623          0.128981
4  1024.0         3.113690          0.129510
5  2048.0         3.165740          0.129224
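A plausible FLOP model behind the TFLOPS column, for readers reproducing the numbers: sparse MLA does two matmuls per query/top-k pair, Q·K over D + D_PE dims and P·V over D dims, at 2 FLOPs per multiply-accumulate. This model is an assumption inferred from the shapes above, not taken from the PR's benchmark code.

```python
def sparse_mla_flops(B, S, H, topk, D, D_PE):
    """Assumed FLOP count for sparse MLA forward (2 FLOPs per MAC)."""
    qk = 2 * B * H * S * topk * (D + D_PE)  # Q @ K^T over content + PE dims
    pv = 2 * B * H * S * topk * D           # attn @ V over content dims
    return qk + pv

# Benchmark config from the table: B=1, S=256, H=16, D=128, D_PE=64, topk=2048
flops = sparse_mla_flops(B=1, S=256, H=16, topk=2048, D=128, D_PE=64)
# TFLOPS would then be flops / elapsed_seconds / 1e12.
```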

CI Configuration

config:
  build: true
  # valid options are "ops", "benchmark", and "sanity"
  test: [ops, benchmark]

Checklist

  • Code formatted and imports sorted per repo specifications (./format.sh)
  • Documentation updated (if needed)
  • CI configuration reviewed

Implements a sparse variant of the dense MLA operator where each query
position attends only to the top-k KV entries specified by an index tensor.

Features TILE_H multi-head-per-block tiling that amortizes the K/V/KPE
gather cost across heads (~2.9x speedup), 3-path config dispatch
(explicit kernel_configs / DISABLE_AUTOTUNE / autotune), and config
validation with GQA alignment constraints.
Tests cover basic shapes, GQA, topk == S_kv, irregular shapes, an 8-row
forced-config correctness matrix (TILE_H values, non-pow2 group sizes),
and invalid kernel_configs rejection; 17 test cases total.
The CuTile vs. PyTorch benchmark sweeps topk with TFLOPS reporting and
includes correctness validation at small topk values.
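The 3-path config dispatch and GQA alignment validation mentioned above could look roughly like the sketch below. All names here (`validate_config`, `select_config`, the exact validation rules) are hypothetical, reconstructed from the commit message and the test descriptions, not the PR's actual code.

```python
import os

def validate_config(cfg, group_size):
    """Hypothetical rules mirroring the invalid-config rejection tests."""
    if "TILE_H" not in cfg:
        raise ValueError("partial config: TILE_H is required")
    tile_h = cfg["TILE_H"]
    if tile_h <= 0 or (tile_h & (tile_h - 1)):
        raise ValueError(f"TILE_H={tile_h} must be a power of two")
    # GQA alignment: a tile of query heads must either fit inside one
    # KV group or cover whole groups, so each tile maps to one KV head set.
    if group_size % tile_h != 0 and tile_h % group_size != 0:
        raise ValueError(f"TILE_H={tile_h} misaligned with group_size={group_size}")

def select_config(kernel_configs=None, group_size=1, autotune=None):
    """Hypothetical 3-path dispatch: explicit / DISABLE_AUTOTUNE / autotune."""
    if kernel_configs is not None:                 # path 1: explicit config
        validate_config(kernel_configs, group_size)
        return kernel_configs
    if os.environ.get("DISABLE_AUTOTUNE") == "1":  # path 2: static default
        return {"TILE_H": 1}
    if autotune is not None:                       # path 3: autotune search
        return autotune()
    return {"TILE_H": 1}
```

The alignment rule is why TILE_H can exceed the group size only in whole-group multiples; with a non-pow2 group size such as 3, only TILE_H=1 (or tiles spanning full groups) would pass.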
@copy-pr-bot

copy-pr-bot Bot commented Mar 31, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@hannahli-nv
Collaborator

/ok to test ea2bcdf

@hannahli-nv
Collaborator

/ok to test 81b89f8

)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("backend", _backends)
def test_basic(self, B, H, S, S_kv, D, D_PE, H_kv, topk, dtype, backend, arch):
Collaborator


Thanks for the contribution! The kernel implementation and tests look solid overall. One issue needs fixing before merge:

The CI workflow runs ops tests with pytest -s tests/ops tests/suites -v -k test_op. Your current test methods — test_basic, test_gqa, test_topk_equals_skv, test_irregular_shapes, test_tile_h_configs, test_invalid_kernel_configs — will all be silently skipped by CI because none of them match the test_op pattern.

Please rename your test methods to use the test_op_ prefix, e.g. test_basic → test_op_basic. Sorry for the inconvenience.

Contributor Author


Thanks for the careful review! I've renamed the test cases as required.

@hannahli-nv
Collaborator

/ok to test 33bc2b5

@hannahli-nv hannahli-nv merged commit b1a81d3 into NVIDIA:main Apr 2, 2026
18 checks passed
