From f9d62b31b845c16ecb71760e2d19013885b5b6d0 Mon Sep 17 00:00:00 2001
From: Jack Taylor <108682042+jataylo@users.noreply.github.com>
Date: Mon, 16 Jun 2025 15:14:44 +0100
Subject: [PATCH] Flex_attention logic error resolved

num_warps==8 configs are always skipped, causing breakages
---
 torch/_inductor/kernel/flex_attention.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py
index 2318be5c423e5..f93334cb4a4d4 100644
--- a/torch/_inductor/kernel/flex_attention.py
+++ b/torch/_inductor/kernel/flex_attention.py
@@ -2273,9 +2273,6 @@ def flex_attention_backward(*args, **kwargs):
             or SPARSE_Q_BLOCK_SIZE % BLOCK2 != 0
         ):
             continue
-        if num_warps == 8:
-            # Working around https://github.com/pytorch/pytorch/issues/141603
-            continue
 
         # Performance tuning
         cur_kernel_options = original_kernel_options.copy()
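
For context, below is a minimal sketch of the config-filtering pattern this hunk touches. It is not the actual torch/_inductor code; the config list, tuple layout, and variable names (candidate_configs, selected) are assumed purely for illustration of why the removed guard dropped every num_warps == 8 candidate from autotuning.

```python
# Illustrative sketch only; identifiers other than the BLOCK/SPARSE names
# quoted in the diff are made up for this example.
candidate_configs = [
    # (BLOCK1, BLOCK2, num_warps, num_stages)
    (32, 32, 4, 1),
    (64, 64, 8, 1),    # with the old guard, any num_warps == 8 config was dropped
    (128, 64, 8, 3),
]

SPARSE_Q_BLOCK_SIZE = 128
SPARSE_KV_BLOCK_SIZE = 128

selected = []
for BLOCK1, BLOCK2, num_warps, num_stages in candidate_configs:
    # Block sizes must evenly divide the sparse block sizes.
    if (
        SPARSE_KV_BLOCK_SIZE % BLOCK1 != 0
        or SPARSE_Q_BLOCK_SIZE % BLOCK1 != 0
        or SPARSE_KV_BLOCK_SIZE % BLOCK2 != 0
        or SPARSE_Q_BLOCK_SIZE % BLOCK2 != 0
    ):
        continue
    # The guard removed by this patch sat here:
    #   if num_warps == 8:
    #       continue
    # which unconditionally excluded num_warps == 8 configs from tuning.
    selected.append((BLOCK1, BLOCK2, num_warps, num_stages))

print(selected)
```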