diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py
index 894801d6cffb5..67baa5a5f1eaf 100644
--- a/torch/_inductor/runtime/triton_heuristics.py
+++ b/torch/_inductor/runtime/triton_heuristics.py
@@ -2415,20 +2415,19 @@ def pointwise(
             triton_config_with_settings(
                 size_hints, TRITON_MAX_BLOCK["X"], waves_per_eu=2
             ),
-            triton_config_with_settings(
-                size_hints, 4096  # wrt: better than the max_block for some kernel
-            ),
             *hinted_configs,
         ]
         # Additional reduction configs appended for ROCm builds
         if torch.version.hip:
-            configs.append(triton_config_with_settings(
-                size_hints,
-                2048,
-                num_warps=8,
-                num_stages=2,
-                waves_per_eu=1
-            ))  # 20% improvement
+            configs += [
+                triton_config_with_settings(size_hints, 2048, num_warps=8, num_stages=2, waves_per_eu=1),  # 20% improvement  # .. on which workload?
+                triton_config_with_settings(size_hints, 4096),  # wrt1: better than the max_block for some kernels
+                triton_config_with_settings(size_hints, 128, num_warps=2, num_stages=2, waves_per_eu=1),
+                # -> wrt1/t18: 2X improvement: triton_poi_fused_index_put_new_zeros_37,
+                #    triton_poi_fused_index_put_new_zeros_45
+                #    triton_poi_fused_index_put_new_zeros_49
+                #    triton_poi_fused_index_put_new_zeros_54
+            ]
     if len(size_hints) == 2:
         if (
             disable_pointwise_autotuning(inductor_meta)  # or tile_hint == TileHint.SQUARE
@@ -2440,17 +2439,24 @@ def pointwise(
         else:
             configs = [
                 triton_config_with_settings(size_hints, 32, 32),
-                triton_config_with_settings(size_hints, 64, 32),  # wrt: better for some kernels
                 triton_config_with_settings(size_hints, 64, 64),  # ~8% better for fp16
-                triton_config_with_settings(size_hints, 256, 16), 
+                triton_config_with_settings(size_hints, 256, 16),
                 triton_config_with_settings(size_hints, 16, 256),
-                triton_config_with_settings(size_hints, 128, 16),  # wrt: +10% for some kernels
-                triton_config_with_settings(size_hints, 128, 32),  # wrt: ..additional 10% more
-                triton_config_with_settings(size_hints, 32, 512),  # wrt: +30% for some kernels
                 triton_config_with_settings(size_hints, bs, 1),
                 triton_config_with_settings(size_hints, 1, bs),
                 *hinted_configs,
             ]
+            if torch.version.hip:
+                configs += [  # add ROCm-specific 2D configs here
+                ]
+                # bypass the triton_config_with_settings -> triton_config logic
+                if "x" in size_hints and "y" in size_hints:
+                    configs += [
+                        Config({"XBLOCK": 512, "YBLOCK": 8}, num_warps=8),  # wrt1/t21  # triton_poi_fused__unsafe_view_add_addmm_cat_clone_permute_split_with_sizes_view_19
+                        Config({"XBLOCK": 32, "YBLOCK": 128}, num_warps=4),  # wrt2: 570us  # triton_poi_fused_add_transpose_view_52
+                        Config({"XBLOCK": 64, "YBLOCK": 32}, num_warps=8),  # wrt3: 150us  # triton_poi_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_permute_view_103
+                    ]
+
     if len(size_hints) == 3:
         if disable_pointwise_autotuning(inductor_meta):
             configs = [triton_config_with_settings(size_hints, 16, 16, 16)]
@@ -2583,9 +2589,14 @@ def make_config(x, r, num_warps=None, num_stages=1, register_intensive=False, wa
     if torch.version.hip:
         result_configs.extend([
             make_config(1024, 8, num_warps=4, num_stages=1, waves_per_eu=2),
-            make_config(512, 8, num_warps=4, num_stages=1, waves_per_eu=1)
+            make_config(512, 8, num_warps=4, num_stages=1, waves_per_eu=1),
+            make_config(128, 4, num_warps=2, num_stages=1, waves_per_eu=1),  # wrt2: 3X  # triton_red_fused_index_add_index_select_mul_native_layer_norm_native_layer_norm_backward_new_zeros_sigmoid_8
+            make_config(1, 512, num_warps=8, num_stages=1, waves_per_eu=1),  # wrt2: 2X  # triton_red_fused_index_add_index_select_mul_native_layer_norm_native_layer_norm_backward_new_zeros_sigmoid_8-v2 & v3 & v4
+            make_config(1, 4096, num_warps=8, num_stages=1, waves_per_eu=1),  # wrt3: 380 us  # triton_red_fused__to_copy__unsafe_view_clone_native_layer_norm_native_layer_norm_backward_slice_tanh_tanh_backward_153
+            make_config(64, 128, num_warps=4, num_stages=1, waves_per_eu=1),  # wrt3: 170 us  # triton_red_fused__to_copy__unsafe_view_add_addmm_cat_clone_native_layer_norm_permute_tanh_view_16
+            make_config(2, 2048, num_warps=8, num_stages=1, waves_per_eu=1)  # wrt3: 170 us  # triton_red_fused__to_copy__unsafe_view_clone_native_layer_norm_native_layer_norm_backward_permute_tanh_tanh_backward_29
         ])
-        
+
     return result_configs
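
The 2D branch above bypasses triton_config_with_settings and registers raw Config objects keyed by XBLOCK/YBLOCK. As a rough, standalone illustration of how such block-shape candidates compete during autotuning (this is not the inductor code path; the kernel, tensor sizes, and the name add_one_2d below are hypothetical), a minimal Triton sketch could look like this:

import torch
import triton
import triton.language as tl

@triton.autotune(
    configs=[
        # same block shapes as the ROCm-specific configs added in the patch
        triton.Config({"XBLOCK": 512, "YBLOCK": 8}, num_warps=8),
        triton.Config({"XBLOCK": 32, "YBLOCK": 128}, num_warps=4),
        triton.Config({"XBLOCK": 64, "YBLOCK": 32}, num_warps=8),
    ],
    key=["xnumel", "ynumel"],  # re-tune when the problem size changes
)
@triton.jit
def add_one_2d(in_ptr, out_ptr, xnumel, ynumel, XBLOCK: tl.constexpr, YBLOCK: tl.constexpr):
    # hypothetical pointwise kernel over a contiguous (xnumel, ynumel) tensor
    xidx = tl.program_id(0) * XBLOCK + tl.arange(0, XBLOCK)[:, None]
    yidx = tl.program_id(1) * YBLOCK + tl.arange(0, YBLOCK)[None, :]
    mask = (xidx < xnumel) & (yidx < ynumel)
    offs = xidx * ynumel + yidx
    val = tl.load(in_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, val + 1.0, mask=mask)

x = torch.randn(4096, 4096, device="cuda")
y = torch.empty_like(x)
grid = lambda meta: (
    triton.cdiv(x.shape[0], meta["XBLOCK"]),
    triton.cdiv(x.shape[1], meta["YBLOCK"]),
)
add_one_2d[grid](x, y, x.shape[0], x.shape[1])

The autotuner benchmarks each config and caches the winner per (xnumel, ynumel) key, which is why the patch adds several XBLOCK/YBLOCK combinations rather than a single hand-picked one.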