experimental/swa_attention_cutile#107
Conversation
There was a problem hiding this comment.
Pull request overview
Adds an experimental sliding-window attention (SWA) prefill operator backed by a fused cuTile kernel, plus test + benchmark scaffolding and dispatch wiring to expose it through tilegym.ops.
Changes:
- Introduces a cuTile SWA forward kernel + Python launcher, including an optional HuggingFace Mistral monkey-patch wrapper.
- Adds correctness tests for SWA across multiple sequence/window configurations.
- Adds a Triton
perf_reportbenchmark and wires the op into the unifiedtilegym.opsdispatcher and cuTile exports.
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 6 comments.
Show a summary per file
| File | Description |
|---|---|
| tests/ops/experimental/test_swa_attention.py | New SWA correctness tests against a PyTorch reference implementation. |
| tests/benchmark/experimental/bench_swa_attention.py | New Triton perf_report benchmark comparing cuTile SWA vs a PyTorch reference backend. |
| src/tilegym/ops/ops.py | Adds the @dispatch("swa_attention") op API entry point. |
| src/tilegym/ops/cutile/experimental/swa_attention.py | Implements the cuTile SWA prefill kernel + host launcher + HF integration helpers. |
| src/tilegym/ops/cutile/init.py | Exposes the experimental SWA implementation in the cuTile backend package exports. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| def swa_fmha_wrapper(module, q, k, v, attention_mask=None, dropout=0.0, scaling=None, is_causal=None, **kwargs): | ||
| if scaling is None: | ||
| scaling = 1.0 / math.sqrt(q.size(-1)) | ||
| if is_causal is None: | ||
| is_causal = True | ||
|
|
||
| # decode (single token) -- our kernel is a prefill kernel, so we | ||
| # fall back to PyTorch SDPA for autoregressive decode steps. | ||
| # also need to expand KV heads for GQA since SDPA expects matched dims. | ||
| if q.size(-2) == 1: | ||
| if k.size(1) != q.size(1): | ||
| n_rep = q.size(1) // k.size(1) | ||
| k = k.repeat_interleave(n_rep, dim=1) | ||
| v = v.repeat_interleave(n_rep, dim=1) | ||
| # cuDNN backend can fail on some GPUs (e.g. B300), try flash then math | ||
| for be in [torch.nn.attention.SDPBackend.FLASH_ATTENTION, torch.nn.attention.SDPBackend.MATH]: | ||
| try: | ||
| with torch.nn.attention.sdpa_kernel(be): | ||
| o = F.scaled_dot_product_attention(q, k, v, is_causal=False) | ||
| return o.transpose(1, 2).contiguous(), None | ||
| except RuntimeError: | ||
| continue | ||
| raise RuntimeError("no working SDPA backend for decode") | ||
|
|
||
| # prefill -- try to read window size from the model's config (e.g. | ||
| # MistralConfig.sliding_window), fall back to the user-supplied default | ||
| w = getattr(getattr(module, "config", None), "sliding_window", None) | ||
| if w is None or w is False: | ||
| w = window_size | ||
| if w is None: | ||
| w = k.size(-2) # no window at all, full causal | ||
|
|
||
| from tilegym.ops import swa_attention as _swa | ||
|
|
||
| o = _swa(q.half(), k.half(), v.half(), window_size=w, scaling=scaling, is_causal=is_causal, backend=backend) | ||
| return o.transpose(1, 2).contiguous(), None |
There was a problem hiding this comment.
The HF wrapper accepts attention_mask and dropout but ignores both in the prefill path, which will produce incorrect results for padded batches and for any training/inference configuration with non-zero dropout. Consider either (a) explicitly rejecting unsupported cases (raise with a clear message when attention_mask is not None or dropout != 0), or (b) falling back to PyTorch SDPA for those cases.
| # also need to expand KV heads for GQA since SDPA expects matched dims. | ||
| if q.size(-2) == 1: | ||
| if k.size(1) != q.size(1): | ||
| n_rep = q.size(1) // k.size(1) |
There was a problem hiding this comment.
In the decode path, n_rep = q.size(1) // k.size(1) assumes the head counts are divisible. If they are not, repeat_interleave will either do the wrong thing or error later. Add a divisibility check and raise a clear error when q.size(1) % k.size(1) != 0 (and similarly ensure k.size(1) <= q.size(1)).
| n_rep = q.size(1) // k.size(1) | |
| q_heads = q.size(1) | |
| kv_heads = k.size(1) | |
| if kv_heads > q_heads: | |
| raise ValueError( | |
| f"decode path requires k/v head count <= q head count, got q_heads={q_heads}, kv_heads={kv_heads}" | |
| ) | |
| if q_heads % kv_heads != 0: | |
| raise ValueError( | |
| f"decode path requires q head count to be divisible by k/v head count, got q_heads={q_heads}, kv_heads={kv_heads}" | |
| ) | |
| n_rep = q_heads // kv_heads |
| import math | ||
|
|
||
| import torch | ||
| import torch.nn.functional as F |
There was a problem hiding this comment.
torch.nn.functional as F is imported but never used in this benchmark file. Remove the unused import to avoid lint noise and keep the benchmark minimal.
| import torch.nn.functional as F |
| device = torch.device("cuda") | ||
| q = torch.empty(B, H, S, D, device=device, dtype=dtype).normal_(mean=0.0, std=0.5) | ||
| k = torch.empty(B, H, S, D, device=device, dtype=dtype).normal_(mean=0.0, std=0.5) | ||
| v = torch.empty(B, H, S, D, device=device, dtype=dtype).normal_(mean=0.0, std=0.5) |
There was a problem hiding this comment.
This test unconditionally uses a CUDA device (torch.device("cuda")) but never checks torch.cuda.is_available(). On environments where cuda.tile imports successfully but no CUDA device is available, this will raise before the test can be skipped. Add an explicit CUDA availability skip (as done in other ops tests) before creating CUDA tensors.
| @pytest.mark.parametrize("backend", _backends) | ||
| def test_long_context_mistral(self, backend): | ||
| # Mistral-style: 8K context, 4K window | ||
| self._run_test(B=1, H=1, S=8192, D=128, W=4096, dtype=torch.float16, backend=backend) |
There was a problem hiding this comment.
test_long_context_mistral runs the fp32 reference which materializes an 8192x8192 attention matrix (and does a full QK^T matmul). This is extremely expensive and is likely to cause CI timeouts or OOMs. Consider marking this test as @pytest.mark.slow (or gating it behind an env/marker) and/or reducing the reference workload (e.g., smaller S for unit tests, or a banded O(S*W) reference implementation).
|
Hi @DevTechJr, thanks for the contribution! We have received your CLA file and will review your PR later. |
372232d to
3fe1482
Compare
|
/ok to test 3fe1482 |
| # -- basic correctness -- | ||
|
|
||
| @pytest.mark.parametrize("backend", _backends) | ||
| def test_window_equals_seq(self, backend): |
There was a problem hiding this comment.
Change the test function to start with test_op_xxx
There was a problem hiding this comment.
Renamed function accordingly.
|
|
||
|
|
||
| def tile_swa_attention(q, k, v, window_size, scaling=None, is_causal=True, **kwargs): | ||
| # q: (B, H, S_Q, D), k/v: (B, H_K, S_K, D) -- fp16 |
There was a problem hiding this comment.
nit: looks this func only support fp16. suggest check dtype at beginning:
if q.dtype not in (torch.float16,):
raise ValueError(f"SWA kernel requires fp16 input, got {q.dtype}")
There was a problem hiding this comment.
Thank you for the suggestion. Added the dtype check with ValueError for non-fp16 inputs.
| modeling_mistral.apply_rotary_pos_emb = get_apply_rope_func(model="llama") | ||
| modeling_mistral.MistralRMSNorm = get_rms_norm_module() | ||
| modeling_mistral.MistralMLP = get_swiglu_module() | ||
| except Exception: |
There was a problem hiding this comment.
nit:
except ImportError: pass
except Exception: raise
| return torch.matmul(torch.softmax(scores, dim=-1), v.float()).to(q.dtype) | ||
|
|
||
|
|
||
| class TestSWAAttention(common.PyTestCase): |
There was a problem hiding this comment.
if this func supports GQA, pls add a test_op_gqa.
56f2859 to
dcf7262
Compare
|
Hi @azazhu , thank you for the comments. I have resolved feedback from both you and Copilot in my new push. Please test and let me know next steps accordingly. |
| ("cutile", "CuTile SWA", ("blue", "-")) if is_backend_available("cutile") else None, | ||
| ("torch", "PyTorch ref", ("green", "-")), |
There was a problem hiding this comment.
Hi @DevTechJr , I notice that the display names here contain spaces ("CuTile SWA", "PyTorch ref"), which breaks the CI summary page. Please match the convention rename to single-token names:
| ("cutile", "CuTile SWA", ("blue", "-")) if is_backend_available("cutile") else None, | |
| ("torch", "PyTorch ref", ("green", "-")), | |
| ("cutile", "CuTile", ("blue", "-")) if is_backend_available("cutile") else None, | |
| ("torch", "PyTorch", ("green", "-")), |
Sorry for the inconvenience.
|
/ok to test 9c71b16 |
|
/ok to test d58112c |
Description
Adds an experimental sliding window attention (SWA) prefill kernel. This operator is not currently in TileGym and is needed for native SWA support in Mistral, Mixtral, and Gemma 3.
Kernel
Single fused cuTile kernel with online softmax. Each query attends to at most W preceding keys. The kernel skips KV blocks outside the window entirely, giving O(S*W) instead of O(S^2).
bid(0)= Q tile block,bid(1)= batch * headParameters
Autotuned on B300 SXM6 AC (sm_103) with two rounds:
Winner:
TILE_M=64 TILE_N=128 occupancy=2 precision=fastPerformance
SWA vs Flash Attention (PyTorch SDPA), B=1 H=32 D=128 W=4096, NVBench cold measurements:
Cosine similarity 1.00000 vs fp32 reference at all tested lengths. Noise < 1% on all NVBench runs.
Files
New:
src/tilegym/ops/cutile/experimental/swa_attention.py- kernel, host launcher, HF model integrationtests/ops/experimental/test_swa_attention.py- 14 correctness tests (basic, edge cases, various shapes, 8K Mistral context)tests/benchmark/experimental/bench_swa_attention.py- Triton perf_report benchmarkModified:
src/tilegym/ops/ops.py- added@dispatch("swa_attention")src/tilegym/ops/cutile/__init__.py- added experimental imports and__all__entriesHF model integration
Includes
apply_tilegym_swa_to_mistral()which follows the same monkey-patch pattern as the existingapply_tilegym_kernel_to_mistralinmonkey_patch.py. ReplacesALL_ATTENTION_FUNCTIONS["sdpa"]with an SWA-aware wrapper. Tested with Mistral-7B-Instruct-v0.3 end-to-end generation.What this unblocks
From the roadmap:
CI Configuration
Checklist
./format.sh)