Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 28 additions & 17 deletions torch/_inductor/runtime/triton_heuristics.py
Original file line number Diff line number Diff line change
Expand Up @@ -2415,20 +2415,19 @@ def pointwise(
triton_config_with_settings(
size_hints, TRITON_MAX_BLOCK["X"], waves_per_eu=2
),
triton_config_with_settings(
size_hints, 4096 # wrt: better than the max_block for some kernel
),
*hinted_configs,
]
# Additional pointwise configs appended for ROCm builds
if torch.version.hip:
configs.append(triton_config_with_settings(
size_hints,
2048,
num_warps=8,
num_stages=2,
waves_per_eu=1
)) # 20% improvement
configs += [
triton_config_with_settings(size_hints, 2048, num_warps=8, num_stages=2, waves_per_eu=1), # 20% improvement # .. in where?
triton_config_with_settings(size_hints, 4096), # wrt1: better than the max_block for some kernel
triton_config_with_settings(size_hints, 128, num_warps=2, num_stages=2, waves_per_eu=1),
# The 128-block config above (wrt1/t18) gives a ~2x improvement on:
#   triton_poi_fused_index_put_new_zeros_37
#   triton_poi_fused_index_put_new_zeros_45
#   triton_poi_fused_index_put_new_zeros_49
#   triton_poi_fused_index_put_new_zeros_54
]
if len(size_hints) == 2:
if (
disable_pointwise_autotuning(inductor_meta) # or tile_hint == TileHint.SQUARE
Expand All @@ -2440,17 +2439,24 @@ def pointwise(
else:
configs = [
triton_config_with_settings(size_hints, 32, 32),
triton_config_with_settings(size_hints, 64, 32), # wrt: better for some kernels
triton_config_with_settings(size_hints, 64, 64), # ~8% better for fp16
triton_config_with_settings(size_hints, 256, 16),
triton_config_with_settings(size_hints, 256, 16),
triton_config_with_settings(size_hints, 16, 256),
triton_config_with_settings(size_hints, 128, 16), # wrt: +10% for some kernels
triton_config_with_settings(size_hints, 128, 32), # wrt: ..additional 10% more
triton_config_with_settings(size_hints, 32, 512), # wrt: +30% for some kernels
triton_config_with_settings(size_hints, bs, 1),
triton_config_with_settings(size_hints, 1, bs),
*hinted_configs,
]
if torch.version.hip:
configs += [ # add here
]
# Build Config objects directly, bypassing the triton_config_with_settings -> triton_config logic
if "x" in size_hints and "y" in size_hints:
configs += [
Config({"XBLOCK": 512, "YBLOCK": 8}, num_warps=8), # wrt1/t21 # triton_poi_fused__unsafe_view_add_addmm_cat_clone_permute_split_with_sizes_view_19
Config({"XBLOCK": 32, "YBLOCK": 128}, num_warps=4), # wrt2: 570us : triton_poi_fused_add_transpose_view_52
Config({"XBLOCK":64, "YBLOCK": 32}, num_warps=8), # wrt3: 150us: triton_poi_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_permute_view_103
]

if len(size_hints) == 3:
if disable_pointwise_autotuning(inductor_meta):
configs = [triton_config_with_settings(size_hints, 16, 16, 16)]
Expand Down Expand Up @@ -2583,9 +2589,14 @@ def make_config(x, r, num_warps=None, num_stages=1, register_intensive=False, wa
if torch.version.hip:
result_configs.extend([
make_config(1024, 8, num_warps=4, num_stages=1, waves_per_eu=2),
make_config(512, 8, num_warps=4, num_stages=1, waves_per_eu=1)
make_config(512, 8, num_warps=4, num_stages=1, waves_per_eu=1),
make_config(128, 4, num_warps=2, num_stages=1, waves_per_eu=1), # wrt2: 3X # triton_red_fused_index_add_index_select_mul_native_layer_norm_native_layer_norm_backward_new_zeros_sigmoid_8
make_config(1, 512, num_warps=8, num_stages=1, waves_per_eu=1), # wrt2: 2X # triton_red_fused_index_add_index_select_mul_native_layer_norm_native_layer_norm_backward_new_zeros_sigmoid_8-v2 & v3 & v4
make_config(1, 4096, num_warps=8, num_stages=1, waves_per_eu=1), # wrt3: 380 us # triton_red_fused__to_copy__unsafe_view_clone_native_layer_norm_native_layer_norm_backward_slice_tanh_tanh_backward_153
make_config(64, 128, num_warps=4, num_stages=1, waves_per_eu=1), # wrt3: 170 us # triton_red_fused__to_copy__unsafe_view_add_addmm_cat_clone_native_layer_norm_permute_tanh_view_16
make_config(2, 2048, num_warps=8, num_stages=1, waves_per_eu=1) # wrt3: 170 us # triton_red_fused__to_copy__unsafe_view_clone_native_layer_norm_native_layer_norm_backward_permute_tanh_tanh_backward_29
])

return result_configs


Expand Down