From 274659028a68d92d479afd6ff411cb980c8af78a Mon Sep 17 00:00:00 2001
From: AmdSampsa
Date: Mon, 13 Oct 2025 17:46:50 +0000
Subject: [PATCH 1/2] wrt2&3 added

---
 torch/_inductor/runtime/triton_heuristics.py | 46 ++++++++++++++------
 1 file changed, 33 insertions(+), 13 deletions(-)

diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py
index 894801d6cffb5..a567a715ad277 100644
--- a/torch/_inductor/runtime/triton_heuristics.py
+++ b/torch/_inductor/runtime/triton_heuristics.py
@@ -2422,13 +2422,9 @@ def pointwise(
     ]
     # Additional reduction configs appended for ROCm builds
     if torch.version.hip:
-        configs.append(triton_config_with_settings(
-            size_hints,
-            2048,
-            num_warps=8,
-            num_stages=2,
-            waves_per_eu=1
-        ))  # 20% improvement
+        configs += [
+            triton_config_with_settings(size_hints, 2048, num_warps=8, num_stages=2, waves_per_eu=1)  # 20% improvement # ..where?
+        ]
     if len(size_hints) == 2:
         if (
             disable_pointwise_autotuning(inductor_meta)  # or tile_hint == TileHint.SQUARE
@@ -2440,17 +2436,24 @@ def pointwise(
         else:
             configs = [
                 triton_config_with_settings(size_hints, 32, 32),
-                triton_config_with_settings(size_hints, 64, 32),  # wrt: better for some kernels
                 triton_config_with_settings(size_hints, 64, 64),  # ~8% better for fp16
                 triton_config_with_settings(size_hints, 256, 16),
                 triton_config_with_settings(size_hints, 16, 256),
-                triton_config_with_settings(size_hints, 128, 16),  # wrt: +10% for some kernels
-                triton_config_with_settings(size_hints, 128, 32),  # wrt: ..additional 10% more
-                triton_config_with_settings(size_hints, 32, 512),  # wrt: +30% for some kernels
                 triton_config_with_settings(size_hints, bs, 1),
                 triton_config_with_settings(size_hints, 1, bs),
                 *hinted_configs,
             ]
+            if torch.version.hip:
+                configs += [  # add here
+                ]
+                # bypass triton_config_with_settings -> triton_config logic
+                if "x" in size_hints and "y" in size_hints:
+                    cfg = {"XBLOCK": 32, "YBLOCK": 128}
+                    configs += [
+                        Config({"XBLOCK": 32, "YBLOCK": 128}, num_warps=4),  # wrt2: 570us : triton_poi_fused_add_transpose_view_52
+                        Config({"XBLOCK":64, "YBLOCK": 32}, num_warps=8),  # wrt3: 150us: triton_poi_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_permute_view_103
+                    ]
+
     if len(size_hints) == 3:
         if disable_pointwise_autotuning(inductor_meta):
             configs = [triton_config_with_settings(size_hints, 16, 16, 16)]
@@ -2471,6 +2474,12 @@
     configs = _maybe_filter_configs_for_tma_restrictions(inductor_meta, configs)
+    print()
+    print("Pointwise will use following configs")
+    for config in configs:
+        print(">", config)
+    print()
+
     return cached_autotune(
         size_hints,
         configs,
@@ -2583,9 +2592,14 @@ def make_config(x, r, num_warps=None, num_stages=1, register_intensive=False, wa
     if torch.version.hip:
         result_configs.extend([
             make_config(1024, 8, num_warps=4, num_stages=1, waves_per_eu=2),
-            make_config(512, 8, num_warps=4, num_stages=1, waves_per_eu=1)
+            make_config(512, 8, num_warps=4, num_stages=1, waves_per_eu=1),
+            make_config(128, 4, num_warps=2, num_stages=1, waves_per_eu=1),  # wrt2: 3X # triton_red_fused_index_add_index_select_mul_native_layer_norm_native_layer_norm_backward_new_zeros_sigmoid_8
+            make_config(1, 512, num_warps=8, num_stages=1, waves_per_eu=1),  # wrt2: 2X # triton_red_fused_index_add_index_select_mul_native_layer_norm_native_layer_norm_backward_new_zeros_sigmoid_8-v2 & v3 & v4
+            make_config(1, 4096, num_warps=8, num_stages=1, waves_per_eu=1),  # wrt3: 380 us # triton_red_fused__to_copy__unsafe_view_clone_native_layer_norm_native_layer_norm_backward_slice_tanh_tanh_backward_153
+            make_config(64, 128, num_warps=4, num_stages=1, waves_per_eu=1),  # wrt3: 170 us # triton_red_fused__to_copy__unsafe_view_add_addmm_cat_clone_native_layer_norm_permute_tanh_view_16
+            make_config(2, 2048, num_warps=8, num_stages=1, waves_per_eu=1)  # wrt3: 170 us # triton_red_fused__to_copy__unsafe_view_clone_native_layer_norm_native_layer_norm_backward_permute_tanh_tanh_backward_29
         ])
-    
+
     return result_configs
@@ -2685,6 +2699,12 @@ def reduction(
     configs = _reduction_configs(size_hints=size_hints, inductor_meta=inductor_meta)
     configs = _maybe_filter_configs_for_tma_restrictions(inductor_meta, configs)
+    print()
+    print("Reduction will use following configs")
+    for config in configs:
+        print(">", config)
+    print()
+
     return cached_autotune(
         size_hints,
         configs=configs,

From 349eee93d822c7ff92b54843233e8bea20c60594 Mon Sep 17 00:00:00 2001
From: AmdSampsa
Date: Tue, 14 Oct 2025 08:06:49 +0000
Subject: [PATCH 2/2] fixed wrt1 configs, added wrt2&3 configs

---
 torch/_inductor/runtime/triton_heuristics.py | 27 +++++++++-------------
 1 file changed, 9 insertions(+), 18 deletions(-)

diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py
index a567a715ad277..67baa5a5f1eaf 100644
--- a/torch/_inductor/runtime/triton_heuristics.py
+++ b/torch/_inductor/runtime/triton_heuristics.py
@@ -2415,15 +2415,18 @@ def pointwise(
         triton_config_with_settings(
             size_hints, TRITON_MAX_BLOCK["X"], waves_per_eu=2
         ),
-        triton_config_with_settings(
-            size_hints, 4096  # wrt: better than the max_block for some kernel
-        ),
         *hinted_configs,
     ]
     # Additional reduction configs appended for ROCm builds
     if torch.version.hip:
         configs += [
-            triton_config_with_settings(size_hints, 2048, num_warps=8, num_stages=2, waves_per_eu=1)  # 20% improvement # ..where?
+            triton_config_with_settings(size_hints, 2048, num_warps=8, num_stages=2, waves_per_eu=1),  # 20% improvement # .. in where?
+            triton_config_with_settings(size_hints, 4096),  # wrt1: better than the max_block for some kernel
+            triton_config_with_settings(size_hints, 128, num_warps=2, num_stages=2, waves_per_eu=1),
+            # -> wrt1/t18: 2X improvement: triton_poi_fused_index_put_new_zeros_37,
+            #    triton_poi_fused_index_put_new_zeros_45
+            #    triton_poi_fused_index_put_new_zeros_49
+            #    triton_poi_fused_index_put_new_zeros_54
         ]
     if len(size_hints) == 2:
         if (
@@ -2437,7 +2440,7 @@ def pointwise(
             configs = [
                 triton_config_with_settings(size_hints, 32, 32),
                 triton_config_with_settings(size_hints, 64, 64),  # ~8% better for fp16
-                triton_config_with_settings(size_hints, 256, 16), 
+                triton_config_with_settings(size_hints, 256, 16),
                 triton_config_with_settings(size_hints, 16, 256),
                 triton_config_with_settings(size_hints, bs, 1),
                 triton_config_with_settings(size_hints, 1, bs),
                 *hinted_configs,
@@ -2448,8 +2451,8 @@ def pointwise(
                 ]
                 # bypass triton_config_with_settings -> triton_config logic
                 if "x" in size_hints and "y" in size_hints:
-                    cfg = {"XBLOCK": 32, "YBLOCK": 128}
                     configs += [
+                        Config({"XBLOCK": 512, "YBLOCK": 8}, num_warps=8),  # wrt1/t21 # triton_poi_fused__unsafe_view_add_addmm_cat_clone_permute_split_with_sizes_view_19
                         Config({"XBLOCK": 32, "YBLOCK": 128}, num_warps=4),  # wrt2: 570us : triton_poi_fused_add_transpose_view_52
                         Config({"XBLOCK":64, "YBLOCK": 32}, num_warps=8),  # wrt3: 150us: triton_poi_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_permute_view_103
                     ]
@@ -2474,12 +2477,6 @@
     configs = _maybe_filter_configs_for_tma_restrictions(inductor_meta, configs)
-    print()
-    print("Pointwise will use following configs")
-    for config in configs:
-        print(">", config)
-    print()
-
     return cached_autotune(
         size_hints,
         configs,
@@ -2699,12 +2696,6 @@ def reduction(
     configs = _reduction_configs(size_hints=size_hints, inductor_meta=inductor_meta)
     configs = _maybe_filter_configs_for_tma_restrictions(inductor_meta, configs)
-    print()
-    print("Reduction will use following configs")
-    for config in configs:
-        print(">", config)
-    print()
-
     return cached_autotune(
         size_hints,
         configs=configs,