Is your feature request related to a problem? Please describe.
WindowAttention.forward (monai/networks/nets/swin_unetr.py) computes attention by hand: it materializes the full (nWindows*heads, N, N) score matrix, adds the relative position bias, rebuilds and adds the shifted-window mask (via compute_mask) on every forward, runs softmax, then multiplies by V. Materializing that score matrix is the dominant, memory-bandwidth-bound cost. In addition, the mask does not change between forward passes for a given input size, yet it is rebuilt on every call.
Describe the solution you'd like
Add an opt-in use_flash_attention flag (default False) that routes the windowed attention through torch.nn.functional.scaled_dot_product_attention, folding the relative position bias and, for shifted windows, the attention mask into a single additive attn_mask. The fused kernel skips the score-matrix materialization and does the softmax in-kernel. This mirrors the use_flash_attention option already used in MONAI's SelfAttention, CrossAttention and CABlock. Output matches the current path (float32 verified to ~3e-6, bit-exact in float64), no parameters or buffers change so pretrained weights load unchanged, and the flag is threaded exactly as use_v2 and use_checkpoint are.
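To make the folding concrete, here is a minimal standalone sketch, not the MONAI implementation. The function names, the tensor layout (B_, heads, N, head_dim) with B_ = batch * nWindows, and the -100.0 mask fill value are assumptions for illustration. It compares a hand-written path against scaled_dot_product_attention with the bias and shift mask summed into one additive attn_mask.

```python
import torch
import torch.nn.functional as F


def manual_window_attention(q, k, v, rel_pos_bias, shift_mask=None):
    """Reference path: materialize scores, add bias and mask, softmax, multiply by V.

    q, k, v:      (B_, heads, N, head_dim), with B_ = batch * nWindows (assumed layout)
    rel_pos_bias: (heads, N, N)
    shift_mask:   (nWindows, N, N) additive mask (0 or a large negative), or None
    """
    scale = q.shape[-1] ** -0.5
    attn = (q * scale) @ k.transpose(-2, -1)  # full (B_, heads, N, N) score matrix
    attn = attn + rel_pos_bias.unsqueeze(0)
    if shift_mask is not None:
        n_w = shift_mask.shape[0]
        b_, heads, n, _ = attn.shape
        attn = attn.view(b_ // n_w, n_w, heads, n, n) + shift_mask[None, :, None]
        attn = attn.view(b_, heads, n, n)
    return attn.softmax(dim=-1) @ v


def sdpa_window_attention(q, k, v, rel_pos_bias, shift_mask=None):
    """Same computation, with bias and mask folded into a single additive attn_mask."""
    bias = rel_pos_bias.unsqueeze(0)  # (1, heads, N, N), broadcasts over B_
    if shift_mask is not None:
        n_w = shift_mask.shape[0]
        b_, heads, n, _ = q.shape
        # (1, 1, heads, N, N) + (1, nW, 1, N, N) -> (1, nW, heads, N, N)
        bias = bias.unsqueeze(1) + shift_mask[None, :, None]
        bias = bias.expand(b_ // n_w, -1, -1, -1, -1).reshape(b_, heads, n, n)
    # A float attn_mask is added to the scores; the default scale is head_dim ** -0.5.
    return F.scaled_dot_product_attention(q, k, v, attn_mask=bias)


torch.manual_seed(0)
batch, n_windows, heads, n_tokens, head_dim = 2, 4, 3, 8, 16
q, k, v = (torch.randn(batch * n_windows, heads, n_tokens, head_dim) for _ in range(3))
rel_pos_bias = torch.randn(heads, n_tokens, n_tokens)

# Synthetic shift mask: tokens in different regions of a window must not attend to each other.
region = torch.randint(0, 2, (n_windows, n_tokens))
shift_mask = (region[:, :, None] != region[:, None, :]).float() * -100.0  # assumed fill value

out_ref = manual_window_attention(q, k, v, rel_pos_bias, shift_mask)
out_sdpa = sdpa_window_attention(q, k, v, rel_pos_bias, shift_mask)
max_abs_diff = (out_ref - out_sdpa).abs().max().item()
```

Which backend PyTorch selects for a call with an arbitrary float attn_mask depends on the PyTorch version, device and dtype, so the actual speedup should be measured rather than assumed.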
Describe alternatives you've considered
Caching the shifted-window mask instead of rebuilding it each forward. That removes only the mask-construction cost, a small fraction of the forward; moving the whole attention to SDPA subsumes it and is the larger win.
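For comparison, the caching alternative could look roughly like the sketch below. The builder cached_shift_mask_1d is a hypothetical 1D stand-in, not MONAI's compute_mask, and the -100.0 fill value is an assumption. The point is only that memoizing on (size, window, shift, device) avoids rebuilding the mask, while the attention itself is left unchanged.

```python
from functools import lru_cache

import torch


@lru_cache(maxsize=32)
def cached_shift_mask_1d(length, window_size, shift_size, device="cpu"):
    """Hypothetical 1D shifted-window mask builder, memoized on its arguments."""
    # Label the three regions that a cyclic shift creates along the axis.
    region = torch.zeros(length, dtype=torch.long, device=device)
    region[length - window_size : length - shift_size] = 1
    region[length - shift_size :] = 2
    windows = region.view(length // window_size, window_size)  # (nWindows, N)
    differs = windows[:, :, None] != windows[:, None, :]       # (nWindows, N, N)
    # 0 where tokens share a region, a large negative value where they do not.
    return torch.where(differs, -100.0, 0.0)


mask_a = cached_shift_mask_1d(8, 4, 2)
mask_b = cached_shift_mask_1d(8, 4, 2)  # served from the cache, no rebuild
```

The cached tensor is shared between callers, so it must not be modified in place.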