
SwinUNETR WindowAttention: add flash attention (scaled_dot_product_attention) #8973

Description

@aymuos15

Is your feature request related to a problem? Please describe.

WindowAttention.forward (monai/networks/nets/swin_unetr.py) computes attention by hand. It materializes the full (batch*nWindows, heads, N, N) score tensor, adds the relative position bias, adds the shifted-window mask (which compute_mask rebuilds on every forward), runs softmax, and then multiplies by V. Materializing that score tensor is the dominant, memory-bandwidth-bound cost. In addition, the mask is constant for a given input size but is still rebuilt on every call.
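A minimal sketch of the hand-written path described above follows. It assumes MONAI-style shapes: q, k and v of shape (batch*nWindows, heads, N, head_dim), a relative position bias of shape (heads, N, N), and an optional shifted-window mask of shape (nWindows, N, N). The function names are hypothetical, and the element count at the end uses illustrative numbers only (343 windows of 7x7x7 = 343 tokens, 3 heads). It is not a measured configuration.

```python
import torch


def manual_window_attention(q, k, v, bias, mask=None):
    """Hypothetical re-creation of the hand-written windowed attention.

    q, k, v: (B_, H, N, D) with B_ = batch * nWindows (window index varies fastest).
    bias:    (H, N, N) relative position bias.
    mask:    optional (nWindows, N, N) additive shifted-window mask.
    """
    b_, heads, n, d = q.shape
    # The full score tensor is materialized here: (B_, H, N, N).
    attn = (q * d ** -0.5) @ k.transpose(-2, -1)
    attn = attn + bias.unsqueeze(0)
    if mask is not None:
        nw = mask.shape[0]
        # Broadcast the per-window mask over the batch and head dimensions.
        attn = attn.view(b_ // nw, nw, heads, n, n) + mask.unsqueeze(1).unsqueeze(0)
        attn = attn.view(-1, heads, n, n)
    attn = attn.softmax(dim=-1)
    return attn @ v


def score_tensor_numel(batch, n_windows, heads, window_tokens):
    """Number of elements in the (batch * nWindows, H, N, N) score tensor."""
    return batch * n_windows * heads * window_tokens * window_tokens


# Illustrative numbers only: 343 windows of 343 tokens, 3 heads, batch 1.
numel = score_tensor_numel(batch=1, n_windows=343, heads=3, window_tokens=343)
print(numel, "elements,", numel * 4 / 2**20, "MiB in float32")
```

Even at these modest illustrative sizes, the score tensor runs to hundreds of MiB per batch element in float32, before the softmax output is counted.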

Describe the solution you'd like

Add an opt-in use_flash_attention flag (default False) that routes the windowed attention through torch.nn.functional.scaled_dot_product_attention (SDPA). The relative position bias and, for shifted windows, the attention mask are folded into a single additive attn_mask. Where a fused kernel is available, it skips the score-matrix materialization and computes the softmax in-kernel. Because the bias always makes attn_mask non-null, PyTorch's FlashAttention backend itself is not eligible, and on CUDA the call typically dispatches to the memory-efficient backend. This mirrors the use_flash_attention option already used in MONAI's SelfAttention, CrossAttention and CABlock.

Output matches the current path (float32 verified to ~3e-6, bit-exact in float64). No parameters or buffers change, so pretrained weights load unchanged, and the flag is threaded through exactly as use_v2 and use_checkpoint are.
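The proposed opt-in path can be sketched as follows. The function names and the window_attention dispatcher are hypothetical stand-ins, not MONAI code. The sketch assumes the same shapes as the hand-written path and relies on SDPA's default scale of 1/sqrt(head_dim), which matches head_dim ** -0.5 scaling. The bias (H, N, N) and the optional window mask (nW, N, N) are summed into one additive attn_mask before the call, and the two paths are compared in float64 on a toy mask.

```python
import torch
import torch.nn.functional as F


def reference_window_attention(q, k, v, bias, mask=None):
    # Hand-written path: materializes the (B_, H, N, N) score tensor.
    b_, heads, n, d = q.shape
    attn = (q * d ** -0.5) @ k.transpose(-2, -1) + bias.unsqueeze(0)
    if mask is not None:
        nw = mask.shape[0]
        attn = attn.view(b_ // nw, nw, heads, n, n) + mask.unsqueeze(1).unsqueeze(0)
        attn = attn.view(-1, heads, n, n)
    return attn.softmax(dim=-1) @ v


def sdpa_window_attention(q, k, v, bias, mask=None):
    # Fold bias (H, N, N) and optional mask (nW, N, N) into one additive attn_mask.
    b_ = q.shape[0]
    if mask is None:
        attn_mask = bias.unsqueeze(0)  # (1, H, N, N), broadcast over B_
    else:
        nw = mask.shape[0]
        combined = bias.unsqueeze(0) + mask.unsqueeze(1)  # (nW, H, N, N)
        # Tile over the batch; window index varies fastest, as in the manual path.
        attn_mask = combined.repeat(b_ // nw, 1, 1, 1)  # (B_, H, N, N)
    # SDPA's default scale is 1/sqrt(head_dim), the same as the manual path.
    return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask.to(q.dtype))


def window_attention(q, k, v, bias, mask=None, use_flash_attention=False):
    # Hypothetical dispatch mirroring the opt-in flag proposed in this issue.
    fn = sdpa_window_attention if use_flash_attention else reference_window_attention
    return fn(q, k, v, bias, mask)


torch.manual_seed(0)
b, nw, heads, n, d = 2, 4, 3, 8, 16
q, k, v = (torch.randn(b * nw, heads, n, d, dtype=torch.float64) for _ in range(3))
bias = torch.randn(heads, n, n, dtype=torch.float64)
mask = torch.zeros(nw, n, n, dtype=torch.float64)
mask[1:, :, n // 2:] = -100.0  # toy mask: hide half the keys in windows 1..3
ref = window_attention(q, k, v, bias, mask)
fast = window_attention(q, k, v, bias, mask, use_flash_attention=True)
print("max abs diff:", (ref - fast).abs().max().item())
```

Tiling the combined mask with repeat keeps every input 4-D, at the cost of one (B_, H, N, N) mask allocation. This is a sketch of the folding, not a tuned implementation.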

Describe alternatives you've considered

Caching the shifted-window mask instead of rebuilding it on every forward. That removes only the mask-construction cost, which is a small fraction of the forward pass; moving the whole attention computation to SDPA subsumes it and is the larger win.
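For comparison, the caching alternative could look like the sketch below. It is a simplified 2-D stand-in for MONAI's compute_mask (which also handles 3-D) with a hypothetical name. functools.lru_cache keys on the spatial size, window, shift and device, and the -100.0 fill follows the original Swin Transformer convention. Because the cached tensor is shared between calls, callers must not modify it in place.

```python
import functools

import torch


@functools.lru_cache(maxsize=16)
def cached_shift_mask_2d(h, w, window, shift, device="cpu"):
    """Hypothetical cached 2-D shifted-window mask of shape (nW, N, N), N = window**2.

    Assumes h and w are multiples of window and 0 < shift < window.
    The returned tensor is shared between calls, so treat it as read-only.
    """
    img_mask = torch.zeros(1, h, w, 1, device=device)
    region = 0
    slices = (slice(0, -window), slice(-window, -shift), slice(-shift, None))
    for hs in slices:
        for ws in slices:
            img_mask[:, hs, ws, :] = region  # label each shifted region
            region += 1
    # Partition the label map into non-overlapping windows: (nW, N).
    windows = img_mask.view(1, h // window, window, w // window, window, 1)
    windows = windows.permute(0, 1, 3, 2, 4, 5).reshape(-1, window * window)
    # Tokens from different regions must not attend to each other.
    diff = windows.unsqueeze(1) - windows.unsqueeze(2)
    return diff.masked_fill(diff != 0, -100.0)


first = cached_shift_mask_2d(8, 8, 4, 2)
again = cached_shift_mask_2d(8, 8, 4, 2)
print(first.shape, again is first, cached_shift_mask_2d.cache_info().hits)
```

The cache removes only the construction step. The mask is still added to a materialized score tensor on the hand-written path.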

Metadata


Assignees

No one assigned

    Labels

    No labels

    Type

    No type


    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
