From b19458a4af43419fce3920745f9e8579abb882ad Mon Sep 17 00:00:00 2001 From: AmdSampsa Date: Tue, 11 Nov 2025 10:13:10 +0000 Subject: [PATCH 1/3] triton sanity check for 2D POI --- torch/_inductor/runtime/triton_heuristics.py | 22 +++++++++++--------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index 1de1f9a595c9e..798ac7e2e840e 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -2568,17 +2568,19 @@ def pointwise( *hinted_configs, ] if torch.version.hip: - configs += [ # add here - ] - # bypass triton_config_with_settings -> triton_config logic if "x" in size_hints and "y" in size_hints: - configs += [ - Config({"XBLOCK": 512, "YBLOCK": 8}, num_warps=8), # wrt1/t21 # triton_poi_fused__unsafe_view_add_addmm_cat_clone_permute_split_with_sizes_view_19 - Config({"XBLOCK": 32, "YBLOCK": 128}, num_warps=4), # wrt2: 570us : triton_poi_fused_add_transpose_view_52 - Config({"XBLOCK":64, "YBLOCK": 32}, num_warps=8), # wrt3: 150us: triton_poi_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_permute_view_103 - Config({"XBLOCK":64, "YBLOCK": 256}, num_warps=4), # wri0: 70us: triton_poi_fused_clone_tanh_transpose_19 - Config({"XBLOCK":512, "YBLOCK": 64}, num_warps=8), # wri0: 58us: triton_poi_fused_clone_53 - ] + """add 2D tiling configs, but don't use triton_config_with_settings function + as it is buggy and might change the tiling randomly + """ + def addConfig__(xblock:int, yblock:int, num_warps:int): + # only add a tiling config if size is bigger than the tile + if size_hints["x"] >= xblock and size_hints["y"] >= yblock: + configs.append(Config({"XBLOCK": xblock, "YBLOCK": yblock}, num_warps=num_warps)) + addConfig__(512, 8, 8) # wrt1/t21 # triton_poi_fused__unsafe_view_add_addmm_cat_clone_permute_split_with_sizes_view_19 + addConfig__(32, 128, 4) # wrt2: 570us : triton_poi_fused_add_transpose_view_52 + addConfig__(64, 32, 8) # wrt3: 150us: triton_poi_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_permute_view_103 + addConfig__(64, 256, 4) # wri0: 70us: triton_poi_fused_clone_tanh_transpose_19 + addConfig__(512, 64, 8) # wri0: 58us: triton_poi_fused_clone_53 if len(size_hints) == 3: if disable_pointwise_autotuning(inductor_meta): From f8d0c1008738b5ef953ffd410c934fc5bf9c31ae Mon Sep 17 00:00:00 2001 From: AmdSampsa Date: Fri, 14 Nov 2025 10:52:14 +0000 Subject: [PATCH 2/3] blocksize check --- torch/_inductor/runtime/triton_heuristics.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index 798ac7e2e840e..6de0e3dfe2244 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -2574,8 +2574,19 @@ def pointwise( """ def addConfig__(xblock:int, yblock:int, num_warps:int): # only add a tiling config if size is bigger than the tile - if size_hints["x"] >= xblock and size_hints["y"] >= yblock: - configs.append(Config({"XBLOCK": xblock, "YBLOCK": yblock}, num_warps=num_warps)) + # check also for grid overflow + xgrid = (size_hints["x"] + xblock - 1) // xblock + ygrid = (size_hints["y"] + yblock - 1) // yblock + if xgrid > 2147483647: + return + if ygrid > 65535: + return + if size_hints["x"] < xblock: + return + if size_hints["y"] < yblock: + return + # all good, add the config + configs.append(Config({"XBLOCK": xblock, "YBLOCK": yblock}, num_warps=num_warps)) addConfig__(512, 8, 8) # wrt1/t21 # triton_poi_fused__unsafe_view_add_addmm_cat_clone_permute_split_with_sizes_view_19 addConfig__(32, 128, 4) # wrt2: 570us : triton_poi_fused_add_transpose_view_52 addConfig__(64, 32, 8) # wrt3: 150us: triton_poi_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_permute_view_103 From a453dea497328e079d4ac90682327de2275a79ca Mon Sep 17 00:00:00 2001 From: AmdSampsa Date: Fri, 14 Nov 2025 11:04:04 +0000 Subject: [PATCH 3/3] blocksize check etc --- torch/_inductor/runtime/triton_heuristics.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index 6de0e3dfe2244..18b23a44f6572 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -2572,7 +2572,7 @@ def pointwise( """add 2D tiling configs, but don't use triton_config_with_settings function as it is buggy and might change the tiling randomly """ - def addConfig__(xblock:int, yblock:int, num_warps:int): + def addConfig__(xblock:int, yblock:int, num_warps:int, num_stages:int): # only add a tiling config if size is bigger than the tile # check also for grid overflow xgrid = (size_hints["x"] + xblock - 1) // xblock @@ -2586,12 +2586,12 @@ def addConfig__(xblock:int, yblock:int, num_warps:int): if size_hints["y"] < yblock: return # all good, add the config - configs.append(Config({"XBLOCK": xblock, "YBLOCK": yblock}, num_warps=num_warps)) - addConfig__(512, 8, 8) # wrt1/t21 # triton_poi_fused__unsafe_view_add_addmm_cat_clone_permute_split_with_sizes_view_19 - addConfig__(32, 128, 4) # wrt2: 570us : triton_poi_fused_add_transpose_view_52 - addConfig__(64, 32, 8) # wrt3: 150us: triton_poi_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_permute_view_103 - addConfig__(64, 256, 4) # wri0: 70us: triton_poi_fused_clone_tanh_transpose_19 - addConfig__(512, 64, 8) # wri0: 58us: triton_poi_fused_clone_53 + configs.append(Config({"XBLOCK": xblock, "YBLOCK": yblock}, num_warps=num_warps, num_stages=num_stages)) + addConfig__(512, 8, 8,1 ) # wrt1/t21 # triton_poi_fused__unsafe_view_add_addmm_cat_clone_permute_split_with_sizes_view_19 + addConfig__(32, 128, 4, 1) # wrt2: 570us : triton_poi_fused_add_transpose_view_52 + addConfig__(64, 32, 8, 1) # wrt3: 150us: triton_poi_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_permute_view_103 + addConfig__(64, 256, 4, 1) # wri0: 70us: triton_poi_fused_clone_tanh_transpose_19 + addConfig__(512, 64, 8, 1) # wri0: 58us: triton_poi_fused_clone_53 if len(size_hints) == 3: if disable_pointwise_autotuning(inductor_meta):