diff --git a/src/tilegym/ops/cutile/__init__.py b/src/tilegym/ops/cutile/__init__.py index dbf0179e..ee821cd1 100644 --- a/src/tilegym/ops/cutile/__init__.py +++ b/src/tilegym/ops/cutile/__init__.py @@ -42,11 +42,13 @@ from .chunk_gated_delta_rule import chunk_gated_delta_rule from .experimental import mhc from .experimental import sparse_mla + from .experimental import swa_attention from .experimental.fused_linear_cross_entropy import fused_linear_cross_entropy from .experimental.mhc import mhc_apply_residual from .experimental.mhc import mhc_gemm_rms_scale from .experimental.mhc import mhc_sinkhorn from .experimental.sparse_mla import tile_sparse_mla + from .experimental.swa_attention import tile_swa_attention from .flash_decode import fmha_decode from .moe import fused_moe_kernel as invoke_fused_moe_kernel from .moe_align_block import moe_align_block_size @@ -99,6 +101,8 @@ "chunk_gated_delta_rule", "recurrent_gated_delta_rule", "sparse_mla", + "swa_attention", + "tile_swa_attention", ] else: __all__ = [] diff --git a/src/tilegym/ops/cutile/experimental/swa_attention.py b/src/tilegym/ops/cutile/experimental/swa_attention.py new file mode 100644 index 00000000..3dfdb383 --- /dev/null +++ b/src/tilegym/ops/cutile/experimental/swa_attention.py @@ -0,0 +1,302 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: MIT + +# Sliding window attention (SWA) prefill kernel. +# +# Each query attends to at most W keys behind it (plus causal mask). +# Online softmax with exp2 + FTZ for SFU utilization. +# 2D grid: bid(0) = Q tile block, bid(1) = batch*head. +# +# Autotuned on B300: M=64, N=128, occ=2, fast precision. +# Validated with in-house autotuner + NVBench cold measurements. + +import math + +import cuda.tile as ct +import torch +import torch.nn.functional as F + +from tilegym.backend import register_impl +from tilegym.experimental import experimental_kernel + +ConstInt = ct.Constant[int] +ConstBool = ct.Constant[bool] +INV_LOG_2 = 1.0 / math.log(2) # pre-computed for exp2-based softmax +_NEG_INF = -1e30 # sentinel for masked-out attention positions + + +def _cdiv(a, b): + return (a + b - 1) // b + + +# occupancy=2 keeps two CTAs resident per SM for latency hiding. +# the @experimental_kernel tag prints a one-time notice on first launch. +@experimental_kernel +@ct.kernel(occupancy=2) +def swa_fwd_kernel( + Q, + K, + V, + Out, + qk_scale: float, + seq_k: int, + window_size: int, + stride_q: int, # number of Q tiles per head (for flat-buffer indexing) + stride_kv: int, # number of KV tiles per head + TILE_M: ConstInt, # tile rows (query) + TILE_N: ConstInt, # tile cols (key/value) + TILE_D: ConstInt, # head dimension + CAUSAL: ConstBool, +): + # 2D grid: bid(0) iterates over Q-tile blocks, bid(1) over batch*heads + q_block = ct.bid(0) + off_hz = ct.bid(1) + + # map block IDs to flat-buffer offsets so each head's tiles + # don't bleed into the next head's memory region + q_start = q_block * TILE_M + q_offset = off_hz * stride_q + q_block + kv_base = off_hz * stride_kv + + q_tile = ct.load(Q, index=(q_offset, 0), shape=(TILE_M, TILE_D), padding_mode=ct.PaddingMode.ZERO) + + # online softmax running state: max, log-sum-exp, accumulator + m_i = ct.full((TILE_M,), _NEG_INF, dtype=ct.float32) + l_i = ct.zeros((TILE_M,), dtype=ct.float32) + acc = ct.zeros((TILE_M, TILE_D), dtype=ct.float32) + + # convert scale to log2 space so we can use exp2 (maps to SFU hw) + scale_log2 = qk_scale * INV_LOG_2 + offs_m = ct.arange(TILE_M, dtype=ct.int32) + q_start + + # compute the KV block range that intersects with the sliding window. + # this is where the O(S*W) complexity comes from -- we skip blocks + # that are entirely outside the window instead of iterating over all S. + kv_lo = max(0, q_start - window_size + 1) // TILE_N + kv_hi = _cdiv(seq_k, TILE_N) + if CAUSAL: + kv_hi = min(kv_hi, (q_start + TILE_M - 1) // TILE_N + 1) + + for kv_block in range(kv_lo, kv_hi): + kv_start = kv_block * TILE_N + + k_tile = ct.load(K, index=(kv_base + kv_block, 0), shape=(TILE_N, TILE_D), padding_mode=ct.PaddingMode.ZERO) + v_tile = ct.load(V, index=(kv_base + kv_block, 0), shape=(TILE_N, TILE_D), padding_mode=ct.PaddingMode.ZERO) + + # QK^T matmul for this tile pair + qk = ct.mma(q_tile, ct.transpose(k_tile), ct.zeros((TILE_M, TILE_N), dtype=ct.float32)) + + # three-part mask: trailing window, causal upper-triangle, seq bounds + offs_n = ct.arange(TILE_N, dtype=ct.int32) + kv_start + offs_n_2d = ct.expand_dims(offs_n, axis=0) + offs_m_2d = ct.expand_dims(offs_m, axis=1) + + mask = offs_n_2d > (offs_m_2d - window_size) # trailing window + if CAUSAL: + mask = mask & (offs_n_2d <= offs_m_2d) # causal (no future keys) + mask = mask & (offs_n_2d < seq_k) # don't read past actual seq len + + qk = ct.where(mask, qk, ct.full((TILE_M, TILE_N), _NEG_INF, dtype=ct.float32)) + + # online softmax: rescale running state by exp2(old_max - new_max). + # exp2 + flush_to_zero maps directly to the GPU SFU hardware. + m_new = ct.maximum(m_i, ct.max(qk * scale_log2, axis=1)) + alpha = ct.exp2(m_i - m_new, flush_to_zero=True) + p = ct.exp2(qk * scale_log2 - ct.expand_dims(m_new, axis=1), flush_to_zero=True) + + # update running sum and weighted accumulator + l_i = alpha * l_i + ct.sum(p, axis=1) + p_fp16 = ct.astype(p, ct.float16) # downcast for the PV matmul + acc = ct.expand_dims(alpha, axis=1) * acc + ct.mma(p_fp16, v_tile, ct.zeros((TILE_M, TILE_D), dtype=ct.float32)) + m_i = m_new + + # final normalization: divide accumulated values by softmax denominator. + # clamp l_i away from zero so a fully-masked row (all keys outside the + # window) produces zeros rather than NaN. + l_i = ct.maximum(l_i, ct.full((TILE_M,), 1e-6, dtype=ct.float32)) + out = acc / ct.expand_dims(l_i, axis=1) + ct.store(Out, index=(q_offset, 0), tile=ct.astype(out, ct.float16)) + + +# -- host launcher -- + +_DEFAULT_TILE_M = 64 +_DEFAULT_TILE_N = 128 # autotuned: 1.9x faster than N=64 on B300 + + +def tile_swa_attention(q, k, v, window_size, scaling=None, is_causal=True, **kwargs): + # q: (B, H, S_Q, D), k/v: (B, H_K, S_K, D) -- fp16 + if q.dtype not in (torch.float16,): + raise ValueError(f"SWA kernel requires fp16 input, got {q.dtype}") + + B, H, S_Q, D = q.shape + _, H_K, S_K, _ = k.shape + + if scaling is None: + scaling = 1.0 / math.sqrt(D) + if window_size <= 0: + window_size = S_K # non-positive W means full causal + + # expand KV heads for GQA (Mistral uses 8 KV heads for 32 Q heads) + if H_K != H: + if H_K > H or H % H_K != 0: + raise ValueError( + f"Invalid GQA head configuration: query heads H={H} must be an integer multiple of KV heads H_K={H_K}." + ) + kv_repeat = H // H_K + k = k.repeat_interleave(kv_repeat, dim=1) + v = v.repeat_interleave(kv_repeat, dim=1) + + TILE_M = _DEFAULT_TILE_M + TILE_N = _DEFAULT_TILE_N + + # compute strides: how many tiles each head occupies in the flat buffer + stride_q = _cdiv(S_Q, TILE_M) + stride_kv = _cdiv(S_K, TILE_N) + S_Q_padded = stride_q * TILE_M + S_K_padded = stride_kv * TILE_N + + # flatten (B, H, S, D) -> (B*H, S, D) for contiguous tile indexing + q_3d = q.reshape(B * H, S_Q, D) + k_3d = k.reshape(B * H, S_K, D) + v_3d = v.reshape(B * H, S_K, D) + + # pad seq dim to tile boundary so tile loads don't cross head boundaries + if S_Q_padded != S_Q: + q_3d = F.pad(q_3d, (0, 0, 0, S_Q_padded - S_Q)) + if S_K_padded != S_K: + k_3d = F.pad(k_3d, (0, 0, 0, S_K_padded - S_K)) + v_3d = F.pad(v_3d, (0, 0, 0, S_K_padded - S_K)) + + # reshape to (B*H*S_padded, D) -- the kernel indexes this as a 2D tile grid + q_flat = q_3d.reshape(-1, D).contiguous() + k_flat = k_3d.reshape(-1, D).contiguous() + v_flat = v_3d.reshape(-1, D).contiguous() + out_flat = torch.empty_like(q_flat) + + ct.launch( + torch.cuda.current_stream(), + (stride_q, B * H, 1), + swa_fwd_kernel, + ( + q_flat, + k_flat, + v_flat, + out_flat, + scaling, + S_K, + window_size, + stride_q, + stride_kv, + TILE_M, + TILE_N, + D, + is_causal, + ), + ) + + # strip padding and reshape back to (B, H, S_Q, D) + out_3d = out_flat.reshape(B * H, S_Q_padded, D)[:, :S_Q, :] + return out_3d.reshape(B, H, S_Q, D).contiguous() + + +# register as the cutile backend for the "swa_attention" dispatch key +register_impl("swa_attention", backend="cutile")(tile_swa_attention) + + +# -- HuggingFace model integration -- + + +def get_swa_fmha_interface(window_size=4096, backend=None): + """Returns a drop-in replacement for ALL_ATTENTION_FUNCTIONS["sdpa"]. + + Prefill uses the cuTile SWA kernel; decode falls back to SDPA since + our kernel doesn't track absolute position for KV-cache scenarios. + """ + + def swa_fmha_wrapper(module, q, k, v, attention_mask=None, dropout=0.0, scaling=None, is_causal=None, **kwargs): + if scaling is None: + scaling = 1.0 / math.sqrt(q.size(-1)) + if is_causal is None: + is_causal = True + + # decode (single token) -- our kernel is a prefill kernel, so we + # fall back to PyTorch SDPA for autoregressive decode steps. + # also need to expand KV heads for GQA since SDPA expects matched dims. + if q.size(-2) == 1: + if k.size(1) != q.size(1): + q_heads = q.size(1) + kv_heads = k.size(1) + if kv_heads > q_heads or q_heads % kv_heads != 0: + raise ValueError( + f"decode path requires q head count to be a multiple of k/v head count, " + f"got q_heads={q_heads}, kv_heads={kv_heads}" + ) + n_rep = q_heads // kv_heads + k = k.repeat_interleave(n_rep, dim=1) + v = v.repeat_interleave(n_rep, dim=1) + # cuDNN backend can fail on some GPUs (e.g. B300), try flash then math + for be in [torch.nn.attention.SDPBackend.FLASH_ATTENTION, torch.nn.attention.SDPBackend.MATH]: + try: + with torch.nn.attention.sdpa_kernel(be): + o = F.scaled_dot_product_attention(q, k, v, is_causal=False) + return o.transpose(1, 2).contiguous(), None + except RuntimeError: + continue + raise RuntimeError("no working SDPA backend for decode") + + # prefill -- fall back to SDPA for unsupported cases (padded batches + # or training with dropout) since our kernel doesn't handle them. + if attention_mask is not None or dropout != 0.0: + return F.scaled_dot_product_attention( + q, k, v, attn_mask=attention_mask, dropout_p=dropout, is_causal=is_causal + ).transpose(1, 2).contiguous(), None + + # try to read window size from the model's config (e.g. + # MistralConfig.sliding_window), fall back to the user-supplied default + w = getattr(getattr(module, "config", None), "sliding_window", None) + if w is None or w is False: + w = window_size + if w is None: + w = k.size(-2) # no window at all, full causal + + from tilegym.ops import swa_attention as _swa + + o = _swa(q.half(), k.half(), v.half(), window_size=w, scaling=scaling, is_causal=is_causal, backend=backend) + return o.transpose(1, 2).contiguous(), None + + return swa_fmha_wrapper + + +def apply_tilegym_swa_to_mistral(window_size=4096, use_cutile=True): + """Monkey-patch Mistral to route attention through the SWA kernel. + + Call before model creation. Same pattern as TileGym's existing + apply_tilegym_kernel_to_llama / apply_tilegym_kernel_to_mistral. + """ + from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS + from transformers.models.mistral import modeling_mistral + + if use_cutile: + from tilegym.backend import set_backend + + set_backend("cutile") + + # replace the default SDPA attention function with our SWA wrapper + ALL_ATTENTION_FUNCTIONS["sdpa"] = get_swa_fmha_interface(window_size=window_size) + + # also patch RoPE, RMSNorm, and SwiGLU if the tilegym kernels are available. + # these are optional -- attention is the main target. + try: + from tilegym.ops import get_apply_rope_func + from tilegym.ops import get_rms_norm_module + from tilegym.ops import get_swiglu_module + + modeling_mistral.apply_rotary_pos_emb = get_apply_rope_func(model="llama") + modeling_mistral.MistralRMSNorm = get_rms_norm_module() + modeling_mistral.MistralMLP = get_swiglu_module() + except ImportError: + pass + except Exception: + raise diff --git a/src/tilegym/ops/ops.py b/src/tilegym/ops/ops.py index 8411e444..3dbce89f 100644 --- a/src/tilegym/ops/ops.py +++ b/src/tilegym/ops/ops.py @@ -955,3 +955,41 @@ def chunk_gated_delta_rule( Tuple[torch.Tensor, Optional[torch.Tensor]]: output (B, T, H, V), final_state """ raise NotImplementedError(f"chunk_gated_delta_rule is not implemented for {get_current_backend()}") + + +@dispatch( + "swa_attention", +) +def swa_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + window_size: int, + scaling: Optional[float] = None, + is_causal: bool = True, + **kwargs: Any, +): + """ + Sliding window attention (SWA) forward pass. + + Computes causal attention where each query attends to at most ``window_size`` + preceding keys. Uses online softmax with exp2 + flush-to-zero for efficient + SFU utilization on NVIDIA GPUs. + + Supports grouped-query attention (GQA): when ``k``/``v`` have fewer heads than + ``q``, KV heads are expanded automatically via ``repeat_interleave``. + + Args: + q: Query tensor of shape (B, H, S_Q, D), fp16 + k: Key tensor of shape (B, H_K, S_K, D), fp16 + v: Value tensor of shape (B, H_K, S_K, D), fp16 + window_size: Number of preceding keys each query can attend to. + When window_size >= S_K, equivalent to full causal attention. + scaling: QK scaling factor, defaults to 1/sqrt(D) + is_causal: Whether to apply causal masking (default: True) + **kwargs: Additional backend-specific arguments + + Returns: + Output tensor of shape (B, H, S_Q, D), fp16 + """ + raise NotImplementedError(f"swa_attention is not implemented for {get_current_backend()}") diff --git a/tests/benchmark/experimental/bench_swa_attention.py b/tests/benchmark/experimental/bench_swa_attention.py new file mode 100644 index 00000000..60aebc2b --- /dev/null +++ b/tests/benchmark/experimental/bench_swa_attention.py @@ -0,0 +1,100 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: MIT + +import math + +import torch +import triton + +import tilegym +from tilegym.backend import is_backend_available +from tilegym.backend import register_impl + +DEVICE = triton.runtime.driver.active.get_active_torch_device() + + +def reference_swa(q, k, v, window_size, is_causal=True, scaling=None, **kwargs): + """PyTorch reference: full materialized mask, O(S^2).""" + B, H, S_Q, D = q.shape + S_K = k.shape[2] + if scaling is None: + scaling = 1.0 / math.sqrt(D) + scores = torch.matmul(q.float(), k.float().transpose(-2, -1)) * scaling + i = torch.arange(S_Q, device=q.device).unsqueeze(1) + j = torch.arange(S_K, device=q.device).unsqueeze(0) + mask = j > (i - window_size) + if is_causal: + mask = mask & (j <= i) + scores = scores.masked_fill(~mask.unsqueeze(0).unsqueeze(0), float("-inf")) + return torch.matmul(torch.softmax(scores, dim=-1), v.float()).to(q.dtype) + + +register_impl("swa_attention", "torch")(reference_swa) + + +ALL_BACKENDS = [ + ("cutile", "CuTile", ("blue", "-")) if is_backend_available("cutile") else None, + ("torch", "PyTorch", ("green", "-")), +] + + +def get_supported_backends(): + return [p for p in ALL_BACKENDS if p is not None] + + +def create_benchmark_config(dtype): + available_backends = get_supported_backends() + if not available_backends: + return None + backends, names, styles = zip(*available_backends) + dtype_name = str(dtype).split(".")[-1] + return triton.testing.Benchmark( + x_names=["seq_len"], + x_vals=[512, 1024, 2048, 4096, 8192, 16384], + line_arg="backend", + line_vals=list(backends), + line_names=list(names), + styles=list(styles), + ylabel="TFLOPS", + plot_name=f"swa-attention-seq-scaling-{dtype_name}-TFLOPS", + args={ + "dtype": dtype, + "B": 1, + "H": 32, + "D": 128, + "W": 4096, + }, + ) + + +@triton.testing.perf_report([create_benchmark_config(dtype) for dtype in [torch.float16]]) +def bench_swa_attention(seq_len, backend, dtype, B, H, D, W, device=DEVICE): + q = torch.empty(B, H, seq_len, D, device=device, dtype=dtype).normal_(mean=0.0, std=0.3) + k = torch.empty(B, H, seq_len, D, device=device, dtype=dtype).normal_(mean=0.0, std=0.3) + v = torch.empty(B, H, seq_len, D, device=device, dtype=dtype).normal_(mean=0.0, std=0.3) + + eff_w = min(W, seq_len) + + fn = lambda: tilegym.ops.swa_attention( + q, + k, + v, + window_size=eff_w, + is_causal=True, + backend=backend, + ) + + # spot-check correctness at small sizes + if seq_len <= 2048 and backend != "torch": + ref = lambda: reference_swa(q, k, v, window_size=eff_w, is_causal=True) + torch.testing.assert_close(fn(), ref(), atol=5e-2, rtol=1e-2) + + ms = triton.testing.do_bench(fn) + # 2 matmuls per KV block: QK^T and PV, each 2*M*N*K FLOPs + total_flops = 2 * B * H * seq_len * eff_w * D * 2 + return total_flops / (ms * 1e-3) / 1e12 + + +if __name__ == "__main__": + bench_swa_attention.run(print_data=True) diff --git a/tests/ops/experimental/test_swa_attention.py b/tests/ops/experimental/test_swa_attention.py new file mode 100644 index 00000000..07e02586 --- /dev/null +++ b/tests/ops/experimental/test_swa_attention.py @@ -0,0 +1,145 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: MIT + +import math + +import pytest +import torch + +import tilegym +from tests import common +from tilegym.backend import set_backend + +_backends = ["cutile"] + + +def swa_reference(q, k, v, window_size, is_causal=True, scaling=None): + """Pure PyTorch fp32 reference. Materializes the full SxS mask.""" + B, H, S_Q, D = q.shape + S_K = k.shape[2] + if scaling is None: + scaling = 1.0 / math.sqrt(D) + + scores = torch.matmul(q.float(), k.float().transpose(-2, -1)) * scaling + i = torch.arange(S_Q, device=q.device).unsqueeze(1) + j = torch.arange(S_K, device=q.device).unsqueeze(0) + mask = j > (i - window_size) + if is_causal: + mask = mask & (j <= i) + scores = scores.masked_fill(~mask.unsqueeze(0).unsqueeze(0), float("-inf")) + return torch.matmul(torch.softmax(scores, dim=-1), v.float()).to(q.dtype) + + +class TestSWAAttention(common.PyTestCase): + def _run_test(self, B, H, S, D, W, dtype, backend): + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + if not tilegym.is_backend_available(backend): + pytest.skip(f"Backend {backend} is not available") + try: + set_backend(backend) + except Exception as e: + pytest.skip(f"Backend is not supported: {e}") + self.setUp() + + device = torch.device("cuda") + q = torch.empty(B, H, S, D, device=device, dtype=dtype).normal_(mean=0.0, std=0.5) + k = torch.empty(B, H, S, D, device=device, dtype=dtype).normal_(mean=0.0, std=0.5) + v = torch.empty(B, H, S, D, device=device, dtype=dtype).normal_(mean=0.0, std=0.5) + + def test_fn(q, k, v, window_size, is_causal): + return tilegym.ops.swa_attention(q, k, v, window_size=window_size, is_causal=is_causal) + + self.assertCorrectness( + test_fn, + swa_reference, + kwargs={"q": q, "k": k, "v": v, "window_size": W, "is_causal": True}, + rtol=1e-2, + atol=5e-2, + ) + + # -- basic correctness -- + + @pytest.mark.parametrize("backend", _backends) + def test_op_window_equals_seq(self, backend): + self._run_test(B=1, H=1, S=128, D=128, W=128, dtype=torch.float16, backend=backend) + + @pytest.mark.parametrize("backend", _backends) + def test_op_small_window(self, backend): + self._run_test(B=1, H=1, S=256, D=128, W=128, dtype=torch.float16, backend=backend) + + @pytest.mark.parametrize("backend", _backends) + def test_op_window_of_one(self, backend): + self._run_test(B=1, H=1, S=128, D=128, W=1, dtype=torch.float16, backend=backend) + + @pytest.mark.parametrize("backend", _backends) + def test_op_multi_head(self, backend): + self._run_test(B=2, H=8, S=256, D=128, W=128, dtype=torch.float16, backend=backend) + + # -- edge cases -- + + @pytest.mark.parametrize("backend", _backends) + def test_op_seq_not_divisible_by_tile(self, backend): + self._run_test(B=1, H=1, S=100, D=128, W=64, dtype=torch.float16, backend=backend) + + @pytest.mark.parametrize("backend", _backends) + def test_op_window_equals_tile(self, backend): + self._run_test(B=1, H=1, S=256, D=128, W=64, dtype=torch.float16, backend=backend) + + @pytest.mark.parametrize("backend", _backends) + def test_op_first_tokens(self, backend): + self._run_test(B=1, H=1, S=64, D=128, W=128, dtype=torch.float16, backend=backend) + + @pytest.mark.parametrize("backend", _backends) + def test_op_very_small_seq(self, backend): + self._run_test(B=1, H=1, S=16, D=128, W=8, dtype=torch.float16, backend=backend) + + # -- GQA (Grouped-Query Attention) -- + + @pytest.mark.parametrize("backend", _backends) + def test_op_gqa(self, backend): + """Mistral-style GQA: 32 Q heads, 8 KV heads (4:1 ratio).""" + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + if not tilegym.is_backend_available(backend): + pytest.skip(f"Backend {backend} is not available") + try: + set_backend(backend) + except Exception as e: + pytest.skip(f"Backend is not supported: {e}") + self.setUp() + + B, H, H_K, S, D, W = 1, 32, 8, 256, 128, 128 + dtype = torch.float16 + device = torch.device("cuda") + + q = torch.empty(B, H, S, D, device=device, dtype=dtype).normal_(mean=0.0, std=0.5) + k = torch.empty(B, H_K, S, D, device=device, dtype=dtype).normal_(mean=0.0, std=0.5) + v = torch.empty(B, H_K, S, D, device=device, dtype=dtype).normal_(mean=0.0, std=0.5) + + # the kernel handles GQA expansion internally + out = tilegym.ops.swa_attention(q, k, v, window_size=W, is_causal=True, backend=backend) + + # reference uses pre-expanded KV so it operates on matched head counts + k_exp = k.repeat_interleave(H // H_K, dim=1) + v_exp = v.repeat_interleave(H // H_K, dim=1) + ref = swa_reference(q, k_exp, v_exp, window_size=W, is_causal=True) + + self.assertAllClose(out, ref, rtol=1e-2, atol=5e-2) + + # -- various shapes -- + + @pytest.mark.parametrize( + "backend,S,W", + [(b, s, w) for b in _backends for s, w in [(512, 256), (1024, 512), (2048, 1024), (4096, 2048), (4096, 4096)]], + ) + def test_op_various_configs(self, backend, S, W): + self._run_test(B=1, H=1, S=S, D=128, W=W, dtype=torch.float16, backend=backend) + + @pytest.mark.slow + @pytest.mark.parametrize("backend", _backends) + def test_op_long_context_mistral(self, backend): + # Mistral-style: 8K context, 4K window. + # Marked slow: materializes an 8192x8192 fp32 reference matrix. + self._run_test(B=1, H=1, S=8192, D=128, W=4096, dtype=torch.float16, backend=backend)