diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index 951977af602e5..4cb9660ebe84c 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -2505,6 +2505,9 @@ def pointwise( triton_config_with_settings( size_hints, TRITON_MAX_BLOCK["X"], waves_per_eu=2 ), + triton_config_with_settings( + size_hints, 4096 # wrt: better than the max_block for some kernel + ), *hinted_configs, ] if len(size_hints) == 2: @@ -2518,10 +2521,12 @@ def pointwise( else: configs = [ triton_config_with_settings(size_hints, 32, 32), + triton_config_with_settings(size_hints, 64, 32), # wrt: better for some kernels triton_config_with_settings(size_hints, 64, 64), # ~8% better for fp16 triton_config_with_settings(size_hints, 256, 16), triton_config_with_settings(size_hints, 16, 256), triton_config_with_settings(size_hints, 128, 16), # wrt: +10% for some kernels + triton_config_with_settings(size_hints, 128, 32), # wrt: ..additional 10% more triton_config_with_settings(size_hints, 32, 512), # wrt: +30% for some kernels triton_config_with_settings(size_hints, bs, 1), triton_config_with_settings(size_hints, 1, bs),