[PyTorch] Support for cuDNN-backed flex attention #2984

Open

vcherepanov-nv wants to merge 8 commits into NVIDIA:main from vcherepanov-nv:cudnn-score-mod-3

Conversation

@vcherepanov-nv
Collaborator

Description

This PR introduces an alternative, Python-only code path for the FusedAttention backend for PyTorch. The user can specify score_mod and score_mod_bprop functions, which are routed to the corresponding parameters of the sdpa and sdpa_backward calls to cuDNN FE.
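
For illustration, a minimal call might look like the sketch below. The constructor arguments are standard TE options; score_mod is this PR's new forward argument, and the callback signature (graph, score, tensors) is an assumption inferred from the review discussion below, not a verified API.

```python
import torch
from transformer_engine.pytorch import DotProductAttention

# Hedged sketch: the callback receives the cuDNN FE graph, the pre-softmax
# score tensor node, and the user-supplied side tensors, and must return the
# (possibly modified) score node. Signature inferred from the review below.
def identity_score_mod(graph, score, tensors):
    return score  # no-op placeholder; real mods would add bias, softcap, etc.

dpa = DotProductAttention(
    num_attention_heads=16,
    kv_channels=64,
    attention_dropout=0.0,     # the score_mod path asserts no dropout...
    attn_mask_type="no_mask",  # ...and no masks (see the review below)
    qkv_format="bshd",
)
q = torch.randn(2, 128, 16, 64, dtype=torch.bfloat16, device="cuda")
k, v = torch.randn_like(q), torch.randn_like(q)
out = dpa(q, k, v, score_mod=identity_score_mod)
```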

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • A new code path for the FusedAttention backend, taken when score_mod (and the related parameters) is specified
  • Tests

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
@greptile-apps
Contributor

greptile-apps Bot commented May 13, 2026

Greptile Summary

This PR adds a Python-only score_mod code path for the cuDNN FusedAttention backend, letting callers inject arbitrary score modification callbacks (causal masking, softcapping, relative-position bias, etc.) directly into cuDNN's SDPA and SDPA-backward graphs without touching any C++ code.

  • New FusedAttentionWithScoreModFunc (torch.autograd.Function) builds and caches cuDNN frontend pygraph objects keyed on tensor shapes/strides and callback identity, then executes them on the current CUDA stream (a plain-PyTorch sketch of this pattern follows the list).
  • DotProductAttention.forward and FusedAttention.forward each gain four new optional parameters and an early-exit branch that bypasses the existing backend-selection machinery when score_mod is not None.
  • Tests cover causal masking, softcapping (stateful class-based score_mod), relative-position bias, cache-key stability for bound methods, and version-counter protection against in-place mutations before backward.
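
As a rough mental model of that Function, here is a self-contained, plain-PyTorch sketch of the cache-then-autograd pattern. All names are illustrative: eager matmuls stand in for the compiled cuDNN graph, and score_mod here acts on a realized score tensor rather than on graph nodes.

```python
import torch

_graph_cache = {}  # the PR keys on shapes/strides, dtype, and callback identity

class ScoreModAttnSketch(torch.autograd.Function):
    """Illustrative stand-in: eager math instead of a compiled cuDNN graph."""

    @staticmethod
    def forward(ctx, q, k, v, scale, score_mod):
        key = (tuple(q.shape), q.stride(), q.dtype, id(score_mod))
        _graph_cache.setdefault(key, object())  # placeholder for build-once graph
        scores = score_mod(torch.matmul(q, k.transpose(-2, -1)) * scale)
        probs = torch.softmax(scores, dim=-1)
        ctx.save_for_backward(q, k, v, probs)
        ctx.scale = scale
        return torch.matmul(probs, v)

    @staticmethod
    def backward(ctx, d_out):
        q, k, v, probs = ctx.saved_tensors
        d_v = torch.matmul(probs.transpose(-2, -1), d_out)
        d_probs = torch.matmul(d_out, v.transpose(-2, -1))
        # softmax backward: dS = P * (dP - rowsum(P * dP))
        d_scores = probs * (d_probs - (probs * d_probs).sum(-1, keepdim=True))
        # NOTE: assumes score_mod's gradient is the identity (e.g. adding a
        # constant bias); the PR handles the general case via score_mod_bprop.
        d_q = torch.matmul(d_scores, k) * ctx.scale
        d_k = torch.matmul(d_scores.transpose(-2, -1), q) * ctx.scale
        return d_q, d_k, d_v, None, None

# Usage with an additive-constant score_mod (consistent with the NOTE above):
q, k, v = (torch.randn(2, 8, 64, 32, requires_grad=True) for _ in range(3))
out = ScoreModAttnSketch.apply(q, k, v, 0.176, lambda s: s + 1.0)
out.sum().backward()
```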

Confidence Score: 3/5

Safe to review but not to merge without addressing the callback cache-key correctness issue.

The new FusedAttentionWithScoreModFunc caches compiled cuDNN graphs under a key derived from id() of the score_mod callable. For bound-method score_mods whose instances vary graph topology, a garbage-collected instance can be reallocated at the same address as a new instance of the same class, producing an identical cache key and silently returning a graph built for a different computation. No exception is raised; the wrong attention output is returned and back-propagated.

transformer_engine/pytorch/attention/dot_product_attention/backends.py — specifically _score_mod_callback_cache_key and the module-level _cudnn_score_mod_graph_cache.
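
The hazard is easy to reproduce in plain CPython, where a freed object's address is frequently handed to the next same-sized allocation:

```python
import gc

class ScoreMod:
    def __init__(self, n_layers: int):
        self.n_layers = n_layers  # would vary the built graph's topology

    def __call__(self, score):
        return score  # body irrelevant to the cache-key hazard

old = ScoreMod(n_layers=2)
old_key = ("bound_method", id(old), id(ScoreMod.__call__))
del old                      # instance becomes collectable
gc.collect()
new = ScoreMod(n_layers=8)   # structurally different score_mod...
new_key = ("bound_method", id(new), id(ScoreMod.__call__))
print(old_key == new_key)    # ...frequently True: the stale graph would be hit
```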

Important Files Changed

  • transformer_engine/pytorch/attention/dot_product_attention/backends.py: Adds ~670 lines implementing FusedAttentionWithScoreModFunc, cuDNN graph build/cache helpers, and a graph execution path; the id()-based callback cache key is unsafe for parameterized stateful score_mods after GC address reuse, and the cache is unbounded.
  • transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py: Adds score_mod/score_mod_bprop/score_mod_tensors/score_mod_bprop_tensors parameters and an early-exit code path that bypasses normal backend selection; the backend probe hard-codes the bshd_bshd_bshd layout regardless of the actual qkv_format.
  • tests/pytorch/attention/test_attention.py: Adds comprehensive score_mod tests (causal, softcap, relative-position), cache-key stability tests, and an in-place mutation guard test; coverage is thorough for the intended usage patterns.

Sequence Diagram

sequenceDiagram
    participant User
    participant DPA as DotProductAttention.forward
    participant FA as FusedAttention.forward
    participant Func as FusedAttentionWithScoreModFunc
    participant Cache as _cudnn_score_mod_graph_cache
    participant cuDNN as cuDNN Python Frontend

    User->>DPA: "forward(q, k, v, score_mod=fn, ...)"
    DPA->>DPA: validate score_mod constraints
    DPA->>DPA: get_fused_attn_backend (availability check)
    DPA->>FA: "forward(..., score_mod=fn)"
    FA->>FA: assert no FP8 / CP / dropout / masks
    FA->>Func: apply(is_training, q, k, v, fmt, scale, score_mod, ...)
    Func->>Func: allocate output_layer, stats_bhs1
    Func->>Cache: _get_cudnn_score_mod_fwd_graph(key)
    alt cache miss
        Cache->>cuDNN: pygraph(dtype, device, handle)
        cuDNN->>Func: call score_mod(graph, score_tensor, tensors)
        cuDNN->>cuDNN: "sdpa(..., score_mod=wrapped)"
        cuDNN->>cuDNN: validate / build_operation_graph / build_plans
        Cache-->>Func: _CudnnScoreModFwdGraphEntry
    else cache hit
        Cache-->>Func: _CudnnScoreModFwdGraphEntry
    end
    Func->>cuDNN: graph.execute(variant_pack, workspace, handle)
    cuDNN-->>Func: output_layer filled
    Func->>Func: "ctx.save_for_backward(q, k, v, out, stats, *mod_tensors)"
    Func-->>User: output_layer
    User->>Func: backward(d_out)
    Func->>Cache: _get_cudnn_score_mod_bwd_graph(key)
    alt cache miss
        Cache->>cuDNN: pygraph(dtype, device, handle)
        cuDNN->>Func: call score_mod + score_mod_bprop
        cuDNN->>cuDNN: sdpa_backward(...)
        Cache-->>Func: _CudnnScoreModBwdGraphEntry
    else cache hit
        Cache-->>Func: _CudnnScoreModBwdGraphEntry
    end
    Func->>cuDNN: graph.execute(variant_pack, workspace, handle)
    cuDNN-->>Func: dq, dk, dv filled
    Func-->>User: (None, dq, dk, dv, None x8)

Comments Outside Diff (1)

  1. transformer_engine/pytorch/attention/dot_product_attention/backends.py, lines 1067-1072

    P2 Unnecessary tensor allocations during backward graph build

    The torch.empty_like(query_layer/key_layer/value_layer) calls allocate tensors solely to extract dims and strides for the cuDNN graph, but query_layer, key_layer, and value_layer already have exactly the same shapes and strides. Replacing the three empty_like calls with direct use of the already-present tensors avoids three unnecessary GPU allocations during every cache-miss backward-graph build; see the sketch below.
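
A minimal before/after sketch of the suggested change (variable names follow the comment; in the PR these are CUDA tensors, and the surrounding graph-build code is elided):

```python
import torch

query_layer = torch.randn(2, 128, 16, 64)  # stand-in for the live tensor

# Before (as described above): a throwaway allocation just to read metadata.
proxy = torch.empty_like(query_layer)
dims, strides = proxy.size(), proxy.stride()

# After: the live tensor already carries identical shape and stride metadata.
dims, strides = query_layer.size(), query_layer.stride()
```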


Comment on lines +1273 to +1281
```python
def _score_mod_callback_cache_key(callback: Optional[Callable]) -> Optional[Tuple[Any, ...]]:
    """Create a stable cache key for a score_mod callable."""
    if callback is None:
        return None
    self_obj = getattr(callback, "__self__", None)
    func_obj = getattr(callback, "__func__", None)
    if self_obj is not None and func_obj is not None:
        return ("bound_method", id(self_obj), id(func_obj))
    return ("callable", id(callback))
```

P1 id()-based cache key is unsafe for parameterized bound-method score_mods

id(self_obj) identifies a Python object by its memory address. When a bound-method instance is garbage-collected, Python may immediately reuse that memory for a new instance. If the new instance belongs to the same class (same id(func_obj)), the cache key is identical, so _get_cudnn_score_mod_fwd_graph returns the old compiled graph even though the new instance might construct a structurally different computation — e.g., a score_mod class whose forward loops self.n_layers times. The wrong graph is executed without any error, silently producing incorrect attention outputs.

For stateless module-level functions this is fine (they're never GC'd), but any stateful class-based score_mod where different instances produce different graph topologies can hit this bug in long-running programs. Consider using type(self_obj) and a per-class sequence counter, or requiring callers to provide an explicit cache key.
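
One possible shape for that suggestion: hand each instance a process-unique token instead of its recyclable address. A hedged sketch, with hypothetical names:

```python
import itertools
import weakref

# Tokens are never reused, so a recycled memory address cannot collide.
_instance_tokens: "weakref.WeakKeyDictionary" = weakref.WeakKeyDictionary()
_next_token = itertools.count()

def _stable_cache_token(obj) -> int:
    """Return a process-unique token for obj that survives address reuse."""
    try:
        token = _instance_tokens.get(obj)
        if token is None:
            token = next(_next_token)
            _instance_tokens[obj] = token
        return token
    except TypeError:
        # Objects that cannot be weak-referenced or hashed fall back to id();
        # callers should document that such score_mods must stay alive.
        return id(obj)
```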

Comment on lines 91 to 92
```python
_flash_attn_varlen_fwd = None
_flash_attn_varlen_bwd = None
```

P2 Unbounded module-level graph cache will grow indefinitely

_cudnn_score_mod_graph_cache is a plain dict with no eviction policy. Cache keys encode tensor shapes, strides, dtype, and device, so every new (batch, seq, heads, dim) combination — extremely common in training with variable-length sequences or multi-task workloads — inserts a permanent entry. Each cached cuDNN graph holds compiled CUDA kernels and associated state, which can be several tens of MB. Over a long training run this will silently consume increasing GPU/CPU memory. Consider a bounded LRU cache (e.g., functools.lru_cache or a collections.OrderedDict with a size cap).
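
For reference, a minimal bounded LRU along the suggested lines, using the collections.OrderedDict approach (names illustrative, not the PR's):

```python
from collections import OrderedDict

class BoundedGraphCache:
    """LRU cache with a hard entry cap, evicting least recently used graphs."""

    def __init__(self, max_entries: int = 128):
        self._entries: OrderedDict = OrderedDict()
        self._max_entries = max_entries

    def get(self, key):
        entry = self._entries.get(key)
        if entry is not None:
            self._entries.move_to_end(key)  # refresh recency on hit
        return entry

    def put(self, key, entry):
        self._entries[key] = entry
        self._entries.move_to_end(key)
        while len(self._entries) > self._max_entries:
            self._entries.popitem(last=False)  # evict least recently used
```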

Comment on lines +1556 to +1563
```python
fused_attention_backend = tex.get_fused_attn_backend(
    self.training,
    q_type,
    q_type,
    dpa_utils.QKVLayout["bshd_bshd_bshd"],
    dpa_utils.AttnBiasType["no_bias"],
    dpa_utils.AttnMaskType["no_mask"],
    dpa_utils.SoftmaxType["vanilla"],
```

P2 get_fused_attn_backend availability check always uses bshd_bshd_bshd regardless of actual format

The score_mod path hard-codes dpa_utils.QKVLayout["bshd_bshd_bshd"] for the backend probe, even when the user passes qkv_format="sbhd". The result is only used to gate on NVTE_No_Backend, so in practice it likely works today because backend availability for a given dtype is layout-independent. However, if a future cuDNN version makes SBHD/BSHD support diverge, this probe would give a false-positive (accepts sbhd even though no backend supports it) or false-negative (rejects sbhd when it is actually supported). Using the real layout for the probe would make the check self-documenting and future-proof.
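
A hedged sketch of that fix, assuming dpa_utils.QKVLayout (from the module under review) exposes a matching key for every supported format; the helper name is hypothetical:

```python
def _probe_qkv_layout(qkv_format: str):
    """Derive the backend-probe layout from the real qkv_format.

    Assumes QKVLayout has a key for every supported format,
    e.g. "sbhd" -> "sbhd_sbhd_sbhd", mirroring "bshd_bshd_bshd".
    """
    return dpa_utils.QKVLayout["_".join([qkv_format] * 3)]

# ...then pass _probe_qkv_layout(qkv_format) to tex.get_fused_attn_backend
# in place of the hard-coded dpa_utils.QKVLayout["bshd_bshd_bshd"].
```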
