From 1d39f46ef86857a59ab6c7095a4018972c34f165 Mon Sep 17 00:00:00 2001
From: Jack Taylor
Date: Tue, 2 Sep 2025 10:02:24 +0000
Subject: [PATCH 1/2] Fix for flex attention tuning

---
 torch/_inductor/kernel/flex_attention.py | 16 +---------------
 1 file changed, 1 insertion(+), 15 deletions(-)

diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py
index b1bd86590fa08..11c3949e3dc02 100644
--- a/torch/_inductor/kernel/flex_attention.py
+++ b/torch/_inductor/kernel/flex_attention.py
@@ -2476,25 +2476,11 @@ def flex_attention_backward(*args, **kwargs):
     SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_KV_BLOCK_SIZE)
 
     choices: list[Any] = []
-    configs: list[tuple[int, int, int, int]] = []
-    configs.append(_get_default_config_bwd(query))
-    if config.max_autotune:
-        num_stages_list = [1, 3, 4, 5] if torch.version.hip is None else [1]
-        configs.extend(
-            [
-                (BLOCK1, BLOCK2, w, s)
-                for BLOCK1 in [32, 64]
-                for BLOCK2 in [32, 64, 128]
-                for w in ([4, 8] if BLOCK1 >= 128 or BLOCK2 >= 128 else [4])
-                for s in num_stages_list
-                if BLOCK2 % BLOCK1 == 0
-            ]
-        )
     dtype = query.get_dtype()
     head_dim = V.graph.sizevars.evaluate_static_shape(query.get_size()[-1])
     configs = V.choices.get_flex_attention_bwd_configs(head_dim, dtype)
-
+
     # Default config for warp specialization
     num_consumer_groups, num_buffers_warp_spec = 0, 0

From 00b714b1aedeecba381a416bdd1c3afc7b2fd233 Mon Sep 17 00:00:00 2001
From: Jack Taylor <108682042+jataylo@users.noreply.github.com>
Date: Tue, 2 Sep 2025 11:05:00 +0100
Subject: [PATCH 2/2] Update flex_attention.py

---
 torch/_inductor/kernel/flex_attention.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py
index 11c3949e3dc02..d743cdb05d34b 100644
--- a/torch/_inductor/kernel/flex_attention.py
+++ b/torch/_inductor/kernel/flex_attention.py
@@ -2480,7 +2480,7 @@ def flex_attention_backward(*args, **kwargs):
     dtype = query.get_dtype()
     head_dim = V.graph.sizevars.evaluate_static_shape(query.get_size()[-1])
     configs = V.choices.get_flex_attention_bwd_configs(head_dim, dtype)
-
+
     # Default config for warp specialization
     num_consumer_groups, num_buffers_warp_spec = 0, 0