diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py
index b1bd86590fa08..d743cdb05d34b 100644
--- a/torch/_inductor/kernel/flex_attention.py
+++ b/torch/_inductor/kernel/flex_attention.py
@@ -2476,20 +2476,6 @@ def flex_attention_backward(*args, **kwargs):
     SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_KV_BLOCK_SIZE)
 
     choices: list[Any] = []
-    configs: list[tuple[int, int, int, int]] = []
-    configs.append(_get_default_config_bwd(query))
-    if config.max_autotune:
-        num_stages_list = [1, 3, 4, 5] if torch.version.hip is None else [1]
-        configs.extend(
-            [
-                (BLOCK1, BLOCK2, w, s)
-                for BLOCK1 in [32, 64]
-                for BLOCK2 in [32, 64, 128]
-                for w in ([4, 8] if BLOCK1 >= 128 or BLOCK2 >= 128 else [4])
-                for s in num_stages_list
-                if BLOCK2 % BLOCK1 == 0
-            ]
-        )
 
     dtype = query.get_dtype()
     head_dim = V.graph.sizevars.evaluate_static_shape(query.get_size()[-1])