From 1321c7dab3ecf14fcdd81dc4781896db46a19b73 Mon Sep 17 00:00:00 2001 From: ShaoChunLee Date: Mon, 4 Aug 2025 19:11:44 +0000 Subject: [PATCH 1/9] add target to make_ir for triton compatibility (cherry picked from commit 7bcbafe925575cf4b2c4d20a86c18e2295e69bc6) --- torch/_higher_order_ops/triton_kernel_wrap.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torch/_higher_order_ops/triton_kernel_wrap.py b/torch/_higher_order_ops/triton_kernel_wrap.py index 34a9c5915254..6d4ea64e78aa 100644 --- a/torch/_higher_order_ops/triton_kernel_wrap.py +++ b/torch/_higher_order_ops/triton_kernel_wrap.py @@ -457,15 +457,15 @@ def get_signature_value(idx: int, arg: Any) -> str: inspect.signature(backend.get_codegen_implementation).parameters ) if make_ir_sig_params == 2: - ttir_module = src.make_ir(options, context) + ttir_module = src.make_ir(target, options, context) elif make_ir_sig_params == 3: codegen_fns = backend.get_codegen_implementation() - ttir_module = src.make_ir(options, codegen_fns, context) + ttir_module = src.make_ir(target, options, codegen_fns, context) else: codegen_args = [options] if get_codegen_implementation_sig_params == 1 else [] codegen_fns = backend.get_codegen_implementation(*codegen_args) module_map = backend.get_module_map() - ttir_module = src.make_ir(options, codegen_fns, module_map, context) + ttir_module = src.make_ir(target, options, codegen_fns, module_map, context) if not ttir_module.verify(): raise RuntimeError("Verification for TTIR module has failed") From 9b631efb86e8f2e0d5209a1c7faa27e59cb2a968 Mon Sep 17 00:00:00 2001 From: Jack Taylor Date: Fri, 8 Aug 2025 13:45:32 +0000 Subject: [PATCH 2/9] Nightly triton bump for perf --- .ci/docker/ci_commit_pins/triton.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.ci/docker/ci_commit_pins/triton.txt b/.ci/docker/ci_commit_pins/triton.txt index 60c896b80c8f..21c0f35bff96 100644 --- a/.ci/docker/ci_commit_pins/triton.txt +++ b/.ci/docker/ci_commit_pins/triton.txt @@ -1 +1 @@ -f7888497a1eb9e98d4c07537f0d0bcfe180d1363 +7a83ab73bd017c4550abd9c41b8176bd63db2858 From e6afdc546aea6ea5eb39c88986ef4388d4a53a2d Mon Sep 17 00:00:00 2001 From: Jack Taylor <108682042+jataylo@users.noreply.github.com> Date: Wed, 30 Jul 2025 21:19:57 +0100 Subject: [PATCH 3/9] [release/2.7] [SWDEV-543214] Reland #2416 Fix warps runtime (#2421) Relands https://github.com/ROCm/pytorch/pull/2416 with caching fix Upstream equivalent https://github.com/pytorch/pytorch/pull/159146 --------- Co-authored-by: Jithun Nair <37884920+jithunnair-amd@users.noreply.github.com> (cherry picked from commit f0aebdc31b8a4b7fa4c0b65a1ad6508e5470fe09) --- torch/_inductor/runtime/coordinate_descent_tuner.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torch/_inductor/runtime/coordinate_descent_tuner.py b/torch/_inductor/runtime/coordinate_descent_tuner.py index 413dfaf09d06..6626c88a1e0d 100644 --- a/torch/_inductor/runtime/coordinate_descent_tuner.py +++ b/torch/_inductor/runtime/coordinate_descent_tuner.py @@ -3,6 +3,7 @@ import itertools import logging from typing import Callable, Optional, TYPE_CHECKING +from functools import lru_cache from .hints import TRITON_MAX_BLOCK from .runtime_utils import red_text, triton_config_to_hashable @@ -60,6 +61,7 @@ def get_config_max(self, prefix: str) -> int: size_hint = self.size_hints.get(prefix) if self.size_hints is not None else None return min(max_block, size_hint) if size_hint is not None else max_block + @lru_cache(maxsize=1) def get_warpsmax(self): # Currently, CUDA has a maximum of 1024 threads, so 32 is the max # number of warps. From ac454885f804c3dd1d3923746bb67d0cb883603c Mon Sep 17 00:00:00 2001 From: Jack Taylor <108682042+jataylo@users.noreply.github.com> Date: Sat, 2 Aug 2025 07:01:33 +0100 Subject: [PATCH 4/9] [release/2.7] [SWDEV-543214] Reland #2416 Fix warps runtime part 2 (#2442) https://github.com/ROCm/pytorch/pull/2421 didn't bring in all required changes to reland https://github.com/ROCm/pytorch/pull/2416 (cherry picked from commit 19431bac8264392b7b02edca06fa489742bf3f49) --- torch/_inductor/runtime/coordinate_descent_tuner.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/torch/_inductor/runtime/coordinate_descent_tuner.py b/torch/_inductor/runtime/coordinate_descent_tuner.py index 6626c88a1e0d..f58f4da06113 100644 --- a/torch/_inductor/runtime/coordinate_descent_tuner.py +++ b/torch/_inductor/runtime/coordinate_descent_tuner.py @@ -63,9 +63,14 @@ def get_config_max(self, prefix: str) -> int: @lru_cache(maxsize=1) def get_warpsmax(self): - # Currently, CUDA has a maximum of 1024 threads, so 32 is the max - # number of warps. - return 1024 // 32 + # CUDA/ROCm has a maximum of 1024 threads per block + from torch.cuda import current_device, get_device_properties, is_available + + warp_size = ( + get_device_properties(current_device()).warp_size if is_available() else 32 + ) + + return 1024 // warp_size def cache_benchmark_result(self, config, timing): self.cached_benchmark_results[triton_config_to_hashable(config)] = timing From 211d11011298c05084b71a524d95c3ed1b13acc5 Mon Sep 17 00:00:00 2001 From: Jack Taylor <108682042+jataylo@users.noreply.github.com> Date: Wed, 30 Jul 2025 21:24:19 +0100 Subject: [PATCH 5/9] [SWDEV-539215] - Autotune support for persistent reduction and no_x_dim removal (#2417) We noticed persistent reduction kernels can be extremely poor performing https://ontrack-internal.amd.com/browse/SWDEV-539215 The root cause is that in certain size restrictions and kernels "no_x_dim" mode is enabled, which embeds static XBLOCK=1 into the kernel. This means tuning is not optimal. Removing this mode and enabling autotune we achieve 2x performance proving that new heuristics must be made. We will bring this into 2.7 for perf uplift, discussion is undergoing with upstream on removing no_x_dim, if there is no perf regression they are in agreement. Draft PR shows no perf loss on ROCm for any inductor benchmark https://github.com/pytorch/pytorch/pull/159048 Removing tests because no longer relevant. (cherry picked from commit 6c845c6c991e15537a73a31712a87a977094e6d6) --- test/inductor/test_combo_kernels.py | 17 ---------- .../test_torchinductor_strided_blocks.py | 25 -------------- torch/_inductor/choices.py | 8 ++--- torch/_inductor/codegen/triton.py | 11 +++--- torch/_inductor/runtime/triton_heuristics.py | 34 ++++++++++++++----- 5 files changed, 34 insertions(+), 61 deletions(-) diff --git a/test/inductor/test_combo_kernels.py b/test/inductor/test_combo_kernels.py index a054464bf668..b6f356e25671 100644 --- a/test/inductor/test_combo_kernels.py +++ b/test/inductor/test_combo_kernels.py @@ -296,23 +296,6 @@ def fn(a0, a1, a2, b0, b1, b2): self.assertTrue(7 <= torch._inductor.metrics.generated_kernel_count <= 8) - @requires_cuda - def test_persistent_reduction_no_x_dim(self): - def fn(x, y): - return x.sum(1), y.sum(1) - - inps = ( - torch.rand(16, 256, device="cuda"), - torch.rand(32, 256, device="cuda"), - ) - torch._dynamo.mark_dynamic(inps[0], 0, min=1, max=256) - torch._dynamo.mark_dynamic(inps[1], 0, min=1, max=256) - out_eager = fn(*inps) - out_compiled = torch.compile(fn)(*inps) - - self.assertEqual(out_eager, out_compiled) - self.assertEqual(torch._inductor.metrics.generated_kernel_count, 4) - @instantiate_parametrized_tests class ComboKernelDynamicShapesTests(TestCase): diff --git a/test/inductor/test_torchinductor_strided_blocks.py b/test/inductor/test_torchinductor_strided_blocks.py index 82bfdd6290bb..2ddb77be3bf0 100644 --- a/test/inductor/test_torchinductor_strided_blocks.py +++ b/test/inductor/test_torchinductor_strided_blocks.py @@ -746,31 +746,6 @@ def test_2d_reduction_odd_shapes( # Check the code for multiple Rn_BLOCK's self._assert_reduction_ndims(code, 2) - def test_2d_reduction_no_x_dim(self): - """ - Tests a 2D reduction without an "x" dimension. - """ - # We need a size to get no x dim. - view = self._discontiguous_tensor((2, 346), self.device) - - # Expect 1 block pointer for the input. - result, (code,) = self._run_and_compare( - torch.prod, - view, - expected_num_block_pointers=1, - expected_num_triton_kernels=1, - config_patches=tiled_reduction_config, - ) - - # Check that there's no X dimension in the signature. - (signature_line,) = ( - line for line in code.splitlines() if line.startswith("def triton") - ) - self.assertNotIn("BLOCK", signature_line) - - # Check for 2 reduction dimensions in the body. - self._assert_reduction_ndims(code, 2) - @parametrize( "size,expected_num_block_pointers,expected_num_triton_kernels,expect_fallback", [ diff --git a/torch/_inductor/choices.py b/torch/_inductor/choices.py index d79db5f2a053..84041326e4ae 100644 --- a/torch/_inductor/choices.py +++ b/torch/_inductor/choices.py @@ -202,11 +202,11 @@ def want_no_x_dim(features: SIMDKernelFeatures) -> bool: Heuristic to decide if we should drop the X dimension from a persistent reduction kernel. So the [XBLOCK, RBLOCK] block becomes a [RBLOCK] block and XBLOCK is forced to be always 1. Strangely this is faster than a [1, RBLOCK] block in some cases. + + ROCm branch change: Remove want_no_x_dim for persistent reduction. + Inductor benchmarks show no perf advantage and simplifies autotune flow. """ - return ( - features.get_reduction_hint() == ReductionHint.INNER - and V.graph.sizevars.statically_known_geq(features.reduction_numel, 256) - ) + return False @staticmethod def reduction_split_factor( diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index f8ad32fafc73..726a5a444dac 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -1999,13 +1999,10 @@ def should_use_persistent_reduction(self) -> bool: ) def want_no_x_dim(self): - if ( - self.persistent_reduction - and len(self.numels) == self.num_reduction_dims + 1 - ): - if self.fixed_config: - return self.fixed_config["XBLOCK"] == 1 - return V.choices.want_no_x_dim(self.features) + """ + ROCm branch change: Remove want_no_x_dim for persistent reduction. + Inductor benchmarks show no perf advantage and simplifies autotune flow. + """ return False @property diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index ba8de8f9829e..078e65bab0d8 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -2689,8 +2689,17 @@ def _persistent_reduction_configs( xnumel = size_hints["x"] rnumel = get_total_reduction_numel(size_hints) - MAX_PERSISTENT_BLOCK_NUMEL = 4096 + max_autotune_enabled = not disable_pointwise_autotuning(inductor_meta) or ( + inductor_meta.get("max_autotune") + or inductor_meta.get("max_autotune_pointwise") + ) + configs = [ + triton_config_reduction(size_hints, xblock, rnumel, register_intensive=True) + for xblock in (1, 8, 32, 128) + if xblock == 1 or (xblock <= xnumel and (max_autotune_enabled or rnumel * xblock <= 4096)) + ] + if "y" not in size_hints: configs = [ triton_config_reduction(size_hints, xblock, rnumel, register_intensive=True) @@ -2718,19 +2727,28 @@ def _persistent_reduction_configs( # defer to more autotuning, initially if "y" in size_hints: pass - # TODO(jansel): we should be able to improve these heuristics - elif reduction_hint == ReductionHint.INNER and rnumel >= 256: - configs = configs[:1] - elif reduction_hint == ReductionHint.OUTER: - configs = configs[-1:] - elif reduction_hint == ReductionHint.OUTER_TINY: - configs = [ + + if not max_autotune_enabled: # Don't filter if tuning enabled + if reduction_hint == ReductionHint.INNER and rnumel >= 256: + configs = configs[:1] + elif reduction_hint == ReductionHint.OUTER: + configs = configs[-1:] + + if reduction_hint == ReductionHint.OUTER_TINY: + tiny_configs = [ triton_config_reduction( size_hints, 2 * (256 // rnumel) if rnumel <= 256 else 1, rnumel, ) ] + if max_autotune_enabled: + for tconfig in tiny_configs: + if tconfig not in configs: + configs.append(tconfig) + else: + configs = tiny_configs + for c in configs: # we don't need Rn_BLOCK for persistent reduction for prefix in size_hints: From d24f08d27581b8c23496bdd13ebcedc470fac64f Mon Sep 17 00:00:00 2001 From: Jack Taylor <108682042+jataylo@users.noreply.github.com> Date: Fri, 18 Jul 2025 19:27:51 +0100 Subject: [PATCH 6/9] [SWDEV-539076] Initial naive foreach autotune support (#2377) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds initial autotuning for foreach support required for https://ontrack-internal.amd.com/browse/SWDEV-539076 4x improvement for some kernels Before: triton_for_fused_18.kd 🔍 | 4.986 ms | 4.986 ms | 2.493 ms | 2 |   triton_for_fused_6.kd 🔍 | 0.098 ms | 0.098 ms | 0.049 ms | 2 |   triton_for_fused_7.kd 🔍 | 0.036 ms | 0.036 ms | 0.018 ms | 2 |   After: triton_for_fused_18.kd 🔍 | 1.273 ms | 1.273 ms | 0.636 ms | 2 |   triton_for_fused_6.kd 🔍 | 0.044 ms | 0.044 ms | 0.022 ms | 2 |   triton_for_fused_7.kd 🔍 | 0.024 ms | 0.024 ms | 0.012 ms | 2 |   (cherry picked from commit f07b7f703543935728e311b6435f7ab58da27bab) --- torch/_inductor/codegen/triton_combo_kernel.py | 2 +- torch/_inductor/runtime/triton_heuristics.py | 15 ++++++++++++--- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/torch/_inductor/codegen/triton_combo_kernel.py b/torch/_inductor/codegen/triton_combo_kernel.py index dc2392119cc5..94a905e4211c 100644 --- a/torch/_inductor/codegen/triton_combo_kernel.py +++ b/torch/_inductor/codegen/triton_combo_kernel.py @@ -614,7 +614,7 @@ def jit_line( if heuristics == "foreach": heuristics_line = f""" @triton_heuristics.foreach( - num_warps={self.num_warps}, + filename=__file__, triton_meta={triton_meta!r}, inductor_meta={inductor_meta!r}, ) diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index 078e65bab0d8..530a2618dc38 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -2940,20 +2940,29 @@ def user_autotune( ) -def foreach(triton_meta, num_warps, filename=None, inductor_meta=None): +def foreach(triton_meta, filename=None, inductor_meta=None): """ Compile a triton foreach kernel """ + configs = [] + if disable_pointwise_autotuning(inductor_meta) and not ( + inductor_meta.get("max_autotune") or + inductor_meta.get("max_autotune_pointwise") + ): + configs.append(triton.Config({}, num_stages=1, num_warps=8)) + else: + for warps in [1, 2, 4, 8]: + configs.append(triton.Config({}, num_stages=1, num_warps=warps)) + return cached_autotune( None, - [triton.Config({}, num_stages=1, num_warps=num_warps)], + configs, triton_meta=triton_meta, inductor_meta=inductor_meta, heuristic_type=HeuristicType.TEMPLATE, filename=filename, ) - @dataclasses.dataclass class GridExpr: """Generate code for grid size expressions in launcher""" From 6b2c3c54b5c761bba3527674c941b32b105551f8 Mon Sep 17 00:00:00 2001 From: Jack Taylor Date: Sun, 10 Aug 2025 19:02:25 -0500 Subject: [PATCH 7/9] Update commit pin --- .ci/docker/ci_commit_pins/triton.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.ci/docker/ci_commit_pins/triton.txt b/.ci/docker/ci_commit_pins/triton.txt index 21c0f35bff96..b98b2ef5bc93 100644 --- a/.ci/docker/ci_commit_pins/triton.txt +++ b/.ci/docker/ci_commit_pins/triton.txt @@ -1 +1 @@ -7a83ab73bd017c4550abd9c41b8176bd63db2858 +b2c0ea435ece3491b2940af7c08d42974b953e06 From a544d27733c9c46b9516e14f5871d13b9c254bb8 Mon Sep 17 00:00:00 2001 From: Jack Taylor Date: Sun, 10 Aug 2025 19:05:19 -0500 Subject: [PATCH 8/9] Add fast tanh --- torch/_inductor/codegen/triton.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 726a5a444dac..0e763772911c 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -1288,7 +1288,7 @@ def tan(x): @staticmethod @maybe_upcast_float32() def tanh(x): - return f"libdevice.tanh({x})" + return f"libdevice.fast_tanhf({x})" @staticmethod @maybe_upcast_float32() From 4afc25a5eb6f4600a121ed3af806f3713340c046 Mon Sep 17 00:00:00 2001 From: Jack Taylor <108682042+jataylo@users.noreply.github.com> Date: Mon, 11 Aug 2025 16:06:38 +0100 Subject: [PATCH 9/9] Fix --- torch/_inductor/runtime/triton_heuristics.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index 530a2618dc38..bbe9b04243e6 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -2705,14 +2705,14 @@ def _persistent_reduction_configs( triton_config_reduction(size_hints, xblock, rnumel, register_intensive=True) for xblock in (1, 8, 32, 128) if xblock == 1 - or (rnumel * xblock <= MAX_PERSISTENT_BLOCK_NUMEL and xblock <= xnumel) + or (rnumel * xblock <= 4096 and xblock <= xnumel) ] else: configs = [] assert "tiling_scores" in inductor_meta x_y_scores = {dim: inductor_meta["tiling_scores"][dim] for dim in ("x", "y")} for target_block_size in (1, 8, 32, 64, 128): - if target_block_size * rnumel > MAX_PERSISTENT_BLOCK_NUMEL: + if target_block_size * rnumel > 4096: continue block_sizes = match_target_block_product(