diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py
index 2318be5c423e5..f93334cb4a4d4 100644
--- a/torch/_inductor/kernel/flex_attention.py
+++ b/torch/_inductor/kernel/flex_attention.py
@@ -2273,9 +2273,6 @@ def flex_attention_backward(*args, **kwargs):
             or SPARSE_Q_BLOCK_SIZE % BLOCK2 != 0
         ):
             continue
-        if num_warps == 8:
-            # Working around https://github.com/pytorch/pytorch/issues/141603
-            continue
 
         # Performance tuning
         cur_kernel_options = original_kernel_options.copy()
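
For context, a minimal sketch of the autotuning config filter this hunk modifies. The names SPARSE_Q_BLOCK_SIZE, SPARSE_KV_BLOCK_SIZE, BLOCK1, BLOCK2, and num_warps come from the diff context; the concrete candidate values and the surrounding loop are illustrative assumptions, not the real kernel-selection code.

    # Sketch: filter candidate (BLOCK1, BLOCK2, num_warps) configs by the same
    # divisibility check visible in the diff context above. Values are assumed
    # for illustration only.
    SPARSE_Q_BLOCK_SIZE = 128
    SPARSE_KV_BLOCK_SIZE = 128

    configs = [
        (BLOCK1, BLOCK2, num_warps)
        for BLOCK1 in (32, 64, 128)
        for BLOCK2 in (32, 64, 128)
        for num_warps in (4, 8)
    ]

    viable = []
    for BLOCK1, BLOCK2, num_warps in configs:
        # Skip block sizes that do not evenly tile the sparse block sizes.
        if (
            SPARSE_KV_BLOCK_SIZE % BLOCK1 != 0
            or SPARSE_Q_BLOCK_SIZE % BLOCK1 != 0
            or SPARSE_KV_BLOCK_SIZE % BLOCK2 != 0
            or SPARSE_Q_BLOCK_SIZE % BLOCK2 != 0
        ):
            continue
        # The deleted workaround additionally skipped num_warps == 8 here
        # (see https://github.com/pytorch/pytorch/issues/141603); with it
        # removed, 8-warp configs re-enter the autotuning pool.
        viable.append((BLOCK1, BLOCK2, num_warps))

    print(viable)

The net effect of the patch is purely to widen the autotuning search space: configs with num_warps == 8 are considered again once the referenced issue no longer applies.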