Skip to content

experimental/swa_attention_cutile#107

Merged
hannahli-nv merged 4 commits into
NVIDIA:mainfrom
DevTechJr:experimental/swa_attention_cutile
Apr 23, 2026
Merged

experimental/swa_attention_cutile#107
hannahli-nv merged 4 commits into
NVIDIA:mainfrom
DevTechJr:experimental/swa_attention_cutile

Conversation

@DevTechJr
Copy link
Copy Markdown
Contributor

@DevTechJr DevTechJr commented Apr 17, 2026

Description

Adds an experimental sliding window attention (SWA) prefill kernel. This operator is not currently in TileGym and is needed for native SWA support in Mistral, Mixtral, and Gemma 3.

Kernel

Single fused cuTile kernel with online softmax. Each query attends to at most W preceding keys. The kernel skips KV blocks outside the window entirely, giving O(S*W) instead of O(S^2).

  • 2D grid: bid(0) = Q tile block, bid(1) = batch * head
  • exp2 + flush-to-zero softmax (maps to SFU hardware)
  • Combined trailing window + causal + seq bounds mask in one pass
  • GQA support via KV head expansion
  • Decode falls back to PyTorch SDPA (this is a prefill kernel)

Parameters

Autotuned on B300 SXM6 AC (sm_103) with two rounds:

  1. In-house grid search over 54 tile/occupancy/precision configs, hardware-pruned to 42, quality-gated at cosine sim >= 0.99 vs fp32 reference, median of 200 CUDA event trials with top-3 re-verification at 500 trials
  2. NVBench (cuda-bench) cold validation with L2 cache flushing, automatic noise detection, and GPU clock locking

Winner: TILE_M=64 TILE_N=128 occupancy=2 precision=fast

Performance

SWA vs Flash Attention (PyTorch SDPA), B=1 H=32 D=128 W=4096, NVBench cold measurements:

Seq SWA (us) SDPA (us) Speedup
4096 258 461 1.78x
8192 678 1571 2.32x
16384 1542 5781 3.75x
32768 3295 22100 6.71x

Cosine similarity 1.00000 vs fp32 reference at all tested lengths. Noise < 1% on all NVBench runs.

Files

New:

  • src/tilegym/ops/cutile/experimental/swa_attention.py - kernel, host launcher, HF model integration
  • tests/ops/experimental/test_swa_attention.py - 14 correctness tests (basic, edge cases, various shapes, 8K Mistral context)
  • tests/benchmark/experimental/bench_swa_attention.py - Triton perf_report benchmark

Modified:

  • src/tilegym/ops/ops.py - added @dispatch("swa_attention")
  • src/tilegym/ops/cutile/__init__.py - added experimental imports and __all__ entries

HF model integration

Includes apply_tilegym_swa_to_mistral() which follows the same monkey-patch pattern as the existing apply_tilegym_kernel_to_mistral in monkey_patch.py. Replaces ALL_ATTENTION_FUNCTIONS["sdpa"] with an SWA-aware wrapper. Tested with Mistral-7B-Instruct-v0.3 end-to-end generation.

What this unblocks

From the roadmap:

  • Mixtral-8x7B (W=4096) - listed as "Help Wanted"
  • Gemma 3 (variable W per layer) - WIP, needs SWA
  • Mistral-7B - has E2E support but currently uses full causal instead of native SWA

CI Configuration

config:
  build: true
  # valid options are "ops", "benchmark", and "sanity"
  test: ["ops", "benchmark"]

Checklist

  • Code formatted and imports sorted via repo specifications (./format.sh)
  • Documentation updated (if needed)
  • CI configuration reviewed

Copilot AI review requested due to automatic review settings April 17, 2026 03:26
@copy-pr-bot
Copy link
Copy Markdown

copy-pr-bot Bot commented Apr 17, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

Copy link
Copy Markdown

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Adds an experimental sliding-window attention (SWA) prefill operator backed by a fused cuTile kernel, plus test + benchmark scaffolding and dispatch wiring to expose it through tilegym.ops.

Changes:

  • Introduces a cuTile SWA forward kernel + Python launcher, including an optional HuggingFace Mistral monkey-patch wrapper.
  • Adds correctness tests for SWA across multiple sequence/window configurations.
  • Adds a Triton perf_report benchmark and wires the op into the unified tilegym.ops dispatcher and cuTile exports.

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 6 comments.

Show a summary per file
File Description
tests/ops/experimental/test_swa_attention.py New SWA correctness tests against a PyTorch reference implementation.
tests/benchmark/experimental/bench_swa_attention.py New Triton perf_report benchmark comparing cuTile SWA vs a PyTorch reference backend.
src/tilegym/ops/ops.py Adds the @dispatch("swa_attention") op API entry point.
src/tilegym/ops/cutile/experimental/swa_attention.py Implements the cuTile SWA prefill kernel + host launcher + HF integration helpers.
src/tilegym/ops/cutile/init.py Exposes the experimental SWA implementation in the cuTile backend package exports.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread src/tilegym/ops/cutile/experimental/swa_attention.py Outdated
Comment on lines +206 to +241
def swa_fmha_wrapper(module, q, k, v, attention_mask=None, dropout=0.0, scaling=None, is_causal=None, **kwargs):
if scaling is None:
scaling = 1.0 / math.sqrt(q.size(-1))
if is_causal is None:
is_causal = True

# decode (single token) -- our kernel is a prefill kernel, so we
# fall back to PyTorch SDPA for autoregressive decode steps.
# also need to expand KV heads for GQA since SDPA expects matched dims.
if q.size(-2) == 1:
if k.size(1) != q.size(1):
n_rep = q.size(1) // k.size(1)
k = k.repeat_interleave(n_rep, dim=1)
v = v.repeat_interleave(n_rep, dim=1)
# cuDNN backend can fail on some GPUs (e.g. B300), try flash then math
for be in [torch.nn.attention.SDPBackend.FLASH_ATTENTION, torch.nn.attention.SDPBackend.MATH]:
try:
with torch.nn.attention.sdpa_kernel(be):
o = F.scaled_dot_product_attention(q, k, v, is_causal=False)
return o.transpose(1, 2).contiguous(), None
except RuntimeError:
continue
raise RuntimeError("no working SDPA backend for decode")

# prefill -- try to read window size from the model's config (e.g.
# MistralConfig.sliding_window), fall back to the user-supplied default
w = getattr(getattr(module, "config", None), "sliding_window", None)
if w is None or w is False:
w = window_size
if w is None:
w = k.size(-2) # no window at all, full causal

from tilegym.ops import swa_attention as _swa

o = _swa(q.half(), k.half(), v.half(), window_size=w, scaling=scaling, is_causal=is_causal, backend=backend)
return o.transpose(1, 2).contiguous(), None
Copy link

Copilot AI Apr 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The HF wrapper accepts attention_mask and dropout but ignores both in the prefill path, which will produce incorrect results for padded batches and for any training/inference configuration with non-zero dropout. Consider either (a) explicitly rejecting unsupported cases (raise with a clear message when attention_mask is not None or dropout != 0), or (b) falling back to PyTorch SDPA for those cases.

Copilot uses AI. Check for mistakes.
# also need to expand KV heads for GQA since SDPA expects matched dims.
if q.size(-2) == 1:
if k.size(1) != q.size(1):
n_rep = q.size(1) // k.size(1)
Copy link

Copilot AI Apr 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the decode path, n_rep = q.size(1) // k.size(1) assumes the head counts are divisible. If they are not, repeat_interleave will either do the wrong thing or error later. Add a divisibility check and raise a clear error when q.size(1) % k.size(1) != 0 (and similarly ensure k.size(1) <= q.size(1)).

Suggested change
n_rep = q.size(1) // k.size(1)
q_heads = q.size(1)
kv_heads = k.size(1)
if kv_heads > q_heads:
raise ValueError(
f"decode path requires k/v head count <= q head count, got q_heads={q_heads}, kv_heads={kv_heads}"
)
if q_heads % kv_heads != 0:
raise ValueError(
f"decode path requires q head count to be divisible by k/v head count, got q_heads={q_heads}, kv_heads={kv_heads}"
)
n_rep = q_heads // kv_heads

Copilot uses AI. Check for mistakes.
import math

import torch
import torch.nn.functional as F
Copy link

Copilot AI Apr 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

torch.nn.functional as F is imported but never used in this benchmark file. Remove the unused import to avoid lint noise and keep the benchmark minimal.

Suggested change
import torch.nn.functional as F

Copilot uses AI. Check for mistakes.
Comment on lines +44 to +47
device = torch.device("cuda")
q = torch.empty(B, H, S, D, device=device, dtype=dtype).normal_(mean=0.0, std=0.5)
k = torch.empty(B, H, S, D, device=device, dtype=dtype).normal_(mean=0.0, std=0.5)
v = torch.empty(B, H, S, D, device=device, dtype=dtype).normal_(mean=0.0, std=0.5)
Copy link

Copilot AI Apr 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test unconditionally uses a CUDA device (torch.device("cuda")) but never checks torch.cuda.is_available(). On environments where cuda.tile imports successfully but no CUDA device is available, this will raise before the test can be skipped. Add an explicit CUDA availability skip (as done in other ops tests) before creating CUDA tensors.

Copilot uses AI. Check for mistakes.
Comment on lines +105 to +108
@pytest.mark.parametrize("backend", _backends)
def test_long_context_mistral(self, backend):
# Mistral-style: 8K context, 4K window
self._run_test(B=1, H=1, S=8192, D=128, W=4096, dtype=torch.float16, backend=backend)
Copy link

Copilot AI Apr 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

test_long_context_mistral runs the fp32 reference which materializes an 8192x8192 attention matrix (and does a full QK^T matmul). This is extremely expensive and is likely to cause CI timeouts or OOMs. Consider marking this test as @pytest.mark.slow (or gating it behind an env/marker) and/or reducing the reference workload (e.g., smaller S for unit tests, or a banded O(S*W) reference implementation).

Copilot uses AI. Check for mistakes.
@hannahli-nv
Copy link
Copy Markdown
Collaborator

Hi @DevTechJr, thanks for the contribution! We have received your CLA file and will review your PR later.

@hannahli-nv hannahli-nv force-pushed the experimental/swa_attention_cutile branch from 372232d to 3fe1482 Compare April 17, 2026 03:56
@hannahli-nv
Copy link
Copy Markdown
Collaborator

/ok to test 3fe1482

Comment thread src/tilegym/ops/cutile/experimental/swa_attention.py
Comment thread src/tilegym/ops/cutile/experimental/swa_attention.py Outdated
# -- basic correctness --

@pytest.mark.parametrize("backend", _backends)
def test_window_equals_seq(self, backend):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Change the test function to start with test_op_xxx

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Renamed function accordingly.



def tile_swa_attention(q, k, v, window_size, scaling=None, is_causal=True, **kwargs):
# q: (B, H, S_Q, D), k/v: (B, H_K, S_K, D) -- fp16
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: looks this func only support fp16. suggest check dtype at beginning:

if q.dtype not in (torch.float16,):
    raise ValueError(f"SWA kernel requires fp16 input, got {q.dtype}")

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the suggestion. Added the dtype check with ValueError for non-fp16 inputs.

modeling_mistral.apply_rotary_pos_emb = get_apply_rope_func(model="llama")
modeling_mistral.MistralRMSNorm = get_rms_norm_module()
modeling_mistral.MistralMLP = get_swiglu_module()
except Exception:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit:
except ImportError: pass
except Exception: raise

return torch.matmul(torch.softmax(scores, dim=-1), v.float()).to(q.dtype)


class TestSWAAttention(common.PyTestCase):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if this func supports GQA, pls add a test_op_gqa.

@DevTechJr DevTechJr force-pushed the experimental/swa_attention_cutile branch from 56f2859 to dcf7262 Compare April 20, 2026 02:16
@DevTechJr
Copy link
Copy Markdown
Contributor Author

Hi @azazhu , thank you for the comments. I have resolved feedback from both you and Copilot in my new push. Please test and let me know next steps accordingly.

@DevTechJr DevTechJr requested a review from azazhu April 20, 2026 02:26
Comment on lines +37 to +38
("cutile", "CuTile SWA", ("blue", "-")) if is_backend_available("cutile") else None,
("torch", "PyTorch ref", ("green", "-")),
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @DevTechJr , I notice that the display names here contain spaces ("CuTile SWA", "PyTorch ref"), which breaks the CI summary page. Please match the convention rename to single-token names:

Suggested change
("cutile", "CuTile SWA", ("blue", "-")) if is_backend_available("cutile") else None,
("torch", "PyTorch ref", ("green", "-")),
("cutile", "CuTile", ("blue", "-")) if is_backend_available("cutile") else None,
("torch", "PyTorch", ("green", "-")),

Sorry for the inconvenience.

@xjmxyt
Copy link
Copy Markdown
Collaborator

xjmxyt commented Apr 23, 2026

/ok to test 9c71b16

@hannahli-nv
Copy link
Copy Markdown
Collaborator

/ok to test d58112c

@hannahli-nv hannahli-nv merged commit 337a7d7 into NVIDIA:main Apr 23, 2026
18 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants