diff --git a/.ci/docker/ci_commit_pins/triton.txt b/.ci/docker/ci_commit_pins/triton.txt
index 567536db7210..c2ee261f27e7 100644
--- a/.ci/docker/ci_commit_pins/triton.txt
+++ b/.ci/docker/ci_commit_pins/triton.txt
@@ -1 +1 @@
-711e2a92522e0a9921ce58ae658571ca55c49b97
+b2c0ea435ece3491b2940af7c08d42974b953e06
\ No newline at end of file
diff --git a/test/inductor/test_combo_kernels.py b/test/inductor/test_combo_kernels.py
index a054464bf668..b6f356e25671 100644
--- a/test/inductor/test_combo_kernels.py
+++ b/test/inductor/test_combo_kernels.py
@@ -296,23 +296,6 @@ def fn(a0, a1, a2, b0, b1, b2):
 
         self.assertTrue(7 <= torch._inductor.metrics.generated_kernel_count <= 8)
 
-    @requires_cuda
-    def test_persistent_reduction_no_x_dim(self):
-        def fn(x, y):
-            return x.sum(1), y.sum(1)
-
-        inps = (
-            torch.rand(16, 256, device="cuda"),
-            torch.rand(32, 256, device="cuda"),
-        )
-        torch._dynamo.mark_dynamic(inps[0], 0, min=1, max=256)
-        torch._dynamo.mark_dynamic(inps[1], 0, min=1, max=256)
-        out_eager = fn(*inps)
-        out_compiled = torch.compile(fn)(*inps)
-
-        self.assertEqual(out_eager, out_compiled)
-        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 4)
-
 
 @instantiate_parametrized_tests
 class ComboKernelDynamicShapesTests(TestCase):
diff --git a/test/inductor/test_torchinductor_strided_blocks.py b/test/inductor/test_torchinductor_strided_blocks.py
index 82bfdd6290bb..2ddb77be3bf0 100644
--- a/test/inductor/test_torchinductor_strided_blocks.py
+++ b/test/inductor/test_torchinductor_strided_blocks.py
@@ -746,31 +746,6 @@ def test_2d_reduction_odd_shapes(
         # Check the code for multiple Rn_BLOCK's
         self._assert_reduction_ndims(code, 2)
 
-    def test_2d_reduction_no_x_dim(self):
-        """
-        Tests a 2D reduction without an "x" dimension.
-        """
-        # We need a size to get no x dim.
-        view = self._discontiguous_tensor((2, 346), self.device)
-
-        # Expect 1 block pointer for the input.
-        result, (code,) = self._run_and_compare(
-            torch.prod,
-            view,
-            expected_num_block_pointers=1,
-            expected_num_triton_kernels=1,
-            config_patches=tiled_reduction_config,
-        )
-
-        # Check that there's no X dimension in the signature.
-        (signature_line,) = (
-            line for line in code.splitlines() if line.startswith("def triton")
-        )
-        self.assertNotIn("BLOCK", signature_line)
-
-        # Check for 2 reduction dimensions in the body.
-        self._assert_reduction_ndims(code, 2)
-
     @parametrize(
         "size,expected_num_block_pointers,expected_num_triton_kernels,expect_fallback",
         [
diff --git a/torch/_higher_order_ops/triton_kernel_wrap.py b/torch/_higher_order_ops/triton_kernel_wrap.py
index 34a9c5915254..6d4ea64e78aa 100644
--- a/torch/_higher_order_ops/triton_kernel_wrap.py
+++ b/torch/_higher_order_ops/triton_kernel_wrap.py
@@ -457,15 +457,15 @@ def get_signature_value(idx: int, arg: Any) -> str:
         inspect.signature(backend.get_codegen_implementation).parameters
     )
     if make_ir_sig_params == 2:
-        ttir_module = src.make_ir(options, context)
+        ttir_module = src.make_ir(target, options, context)
     elif make_ir_sig_params == 3:
         codegen_fns = backend.get_codegen_implementation()
-        ttir_module = src.make_ir(options, codegen_fns, context)
+        ttir_module = src.make_ir(target, options, codegen_fns, context)
     else:
         codegen_args = [options] if get_codegen_implementation_sig_params == 1 else []
         codegen_fns = backend.get_codegen_implementation(*codegen_args)
         module_map = backend.get_module_map()
-        ttir_module = src.make_ir(options, codegen_fns, module_map, context)
+        ttir_module = src.make_ir(target, options, codegen_fns, module_map, context)
 
     if not ttir_module.verify():
         raise RuntimeError("Verification for TTIR module has failed")
diff --git a/torch/_inductor/choices.py b/torch/_inductor/choices.py
index d79db5f2a053..84041326e4ae 100644
--- a/torch/_inductor/choices.py
+++ b/torch/_inductor/choices.py
@@ -202,11 +202,11 @@ def want_no_x_dim(features: SIMDKernelFeatures) -> bool:
         Heuristic to decide if we should drop the X dimension from a persistent reduction kernel.
         So the [XBLOCK, RBLOCK] block becomes a [RBLOCK] block and XBLOCK is forced to be always 1.
         Strangely this is faster than a [1, RBLOCK] block in some cases.
+
+        ROCm branch change: Remove want_no_x_dim for persistent reduction.
+        Inductor benchmarks show no perf advantage and simplifies autotune flow.
         """
-        return (
-            features.get_reduction_hint() == ReductionHint.INNER
-            and V.graph.sizevars.statically_known_geq(features.reduction_numel, 256)
-        )
+        return False
 
     @staticmethod
     def reduction_split_factor(
diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py
index f8ad32fafc73..0e763772911c 100644
--- a/torch/_inductor/codegen/triton.py
+++ b/torch/_inductor/codegen/triton.py
@@ -1288,7 +1288,7 @@ def tan(x):
     @staticmethod
     @maybe_upcast_float32()
     def tanh(x):
-        return f"libdevice.tanh({x})"
+        return f"libdevice.fast_tanhf({x})"
 
     @staticmethod
     @maybe_upcast_float32()
@@ -1999,13 +1999,10 @@ def should_use_persistent_reduction(self) -> bool:
         )
 
     def want_no_x_dim(self):
-        if (
-            self.persistent_reduction
-            and len(self.numels) == self.num_reduction_dims + 1
-        ):
-            if self.fixed_config:
-                return self.fixed_config["XBLOCK"] == 1
-            return V.choices.want_no_x_dim(self.features)
+        """
+        ROCm branch change: Remove want_no_x_dim for persistent reduction.
+        Inductor benchmarks show no perf advantage and simplifies autotune flow.
+        """
         return False
 
     @property
diff --git a/torch/_inductor/codegen/triton_combo_kernel.py b/torch/_inductor/codegen/triton_combo_kernel.py
index dc2392119cc5..94a905e4211c 100644
--- a/torch/_inductor/codegen/triton_combo_kernel.py
+++ b/torch/_inductor/codegen/triton_combo_kernel.py
@@ -614,7 +614,7 @@ def jit_line(
         if heuristics == "foreach":
             heuristics_line = f"""
 @triton_heuristics.foreach(
-    num_warps={self.num_warps},
+    filename=__file__,
     triton_meta={triton_meta!r},
     inductor_meta={inductor_meta!r},
 )
diff --git a/torch/_inductor/runtime/coordinate_descent_tuner.py b/torch/_inductor/runtime/coordinate_descent_tuner.py
index 413dfaf09d06..f58f4da06113 100644
--- a/torch/_inductor/runtime/coordinate_descent_tuner.py
+++ b/torch/_inductor/runtime/coordinate_descent_tuner.py
@@ -3,6 +3,7 @@
 import itertools
 import logging
 from typing import Callable, Optional, TYPE_CHECKING
+from functools import lru_cache
 
 from .hints import TRITON_MAX_BLOCK
 from .runtime_utils import red_text, triton_config_to_hashable
@@ -60,10 +61,16 @@ def get_config_max(self, prefix: str) -> int:
         size_hint = self.size_hints.get(prefix) if self.size_hints is not None else None
         return min(max_block, size_hint) if size_hint is not None else max_block
 
+    @lru_cache(maxsize=1)
     def get_warpsmax(self):
-        # Currently, CUDA has a maximum of 1024 threads, so 32 is the max
-        # number of warps.
-        return 1024 // 32
+        # CUDA/ROCm has a maximum of 1024 threads per block
+        from torch.cuda import current_device, get_device_properties, is_available
+
+        warp_size = (
+            get_device_properties(current_device()).warp_size if is_available() else 32
+        )
+
+        return 1024 // warp_size
 
     def cache_benchmark_result(self, config, timing):
         self.cached_benchmark_results[triton_config_to_hashable(config)] = timing
diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py
index ba8de8f9829e..bbe9b04243e6 100644
--- a/torch/_inductor/runtime/triton_heuristics.py
+++ b/torch/_inductor/runtime/triton_heuristics.py
@@ -2689,21 +2689,30 @@ def _persistent_reduction_configs(
     xnumel = size_hints["x"]
     rnumel = get_total_reduction_numel(size_hints)
 
-    MAX_PERSISTENT_BLOCK_NUMEL = 4096
+    max_autotune_enabled = not disable_pointwise_autotuning(inductor_meta) or (
+        inductor_meta.get("max_autotune")
+        or inductor_meta.get("max_autotune_pointwise")
+    )
+    configs = [
+        triton_config_reduction(size_hints, xblock, rnumel, register_intensive=True)
+        for xblock in (1, 8, 32, 128)
+        if xblock == 1 or (xblock <= xnumel and (max_autotune_enabled or rnumel * xblock <= 4096))
+    ]
+
 
     if "y" not in size_hints:
         configs = [
             triton_config_reduction(size_hints, xblock, rnumel, register_intensive=True)
             for xblock in (1, 8, 32, 128)
             if xblock == 1
-            or (rnumel * xblock <= MAX_PERSISTENT_BLOCK_NUMEL and xblock <= xnumel)
+            or (rnumel * xblock <= 4096 and xblock <= xnumel)
         ]
     else:
         configs = []
         assert "tiling_scores" in inductor_meta
         x_y_scores = {dim: inductor_meta["tiling_scores"][dim] for dim in ("x", "y")}
         for target_block_size in (1, 8, 32, 64, 128):
-            if target_block_size * rnumel > MAX_PERSISTENT_BLOCK_NUMEL:
+            if target_block_size * rnumel > 4096:
                 continue
 
             block_sizes = match_target_block_product(
@@ -2718,19 +2727,28 @@
     # defer to more autotuning, initially
     if "y" in size_hints:
         pass
-    # TODO(jansel): we should be able to improve these heuristics
-    elif reduction_hint == ReductionHint.INNER and rnumel >= 256:
-        configs = configs[:1]
-    elif reduction_hint == ReductionHint.OUTER:
-        configs = configs[-1:]
-    elif reduction_hint == ReductionHint.OUTER_TINY:
-        configs = [
+
+    if not max_autotune_enabled:  # Don't filter if tuning enabled
+        if reduction_hint == ReductionHint.INNER and rnumel >= 256:
+            configs = configs[:1]
+        elif reduction_hint == ReductionHint.OUTER:
+            configs = configs[-1:]
+
+    if reduction_hint == ReductionHint.OUTER_TINY:
+        tiny_configs = [
             triton_config_reduction(
                 size_hints,
                 2 * (256 // rnumel) if rnumel <= 256 else 1,
                 rnumel,
             )
         ]
+        if max_autotune_enabled:
+            for tconfig in tiny_configs:
+                if tconfig not in configs:
+                    configs.append(tconfig)
+        else:
+            configs = tiny_configs
+
     for c in configs:
         # we don't need Rn_BLOCK for persistent reduction
         for prefix in size_hints:
@@ -2922,20 +2940,29 @@ def user_autotune(
     )
 
 
-def foreach(triton_meta, num_warps, filename=None, inductor_meta=None):
+def foreach(triton_meta, filename=None, inductor_meta=None):
     """
     Compile a triton foreach kernel
     """
+    configs = []
+    if disable_pointwise_autotuning(inductor_meta) and not (
+        inductor_meta.get("max_autotune") or
+        inductor_meta.get("max_autotune_pointwise")
+    ):
+        configs.append(triton.Config({}, num_stages=1, num_warps=8))
+    else:
+        for warps in [1, 2, 4, 8]:
+            configs.append(triton.Config({}, num_stages=1, num_warps=warps))
+
     return cached_autotune(
         None,
-        [triton.Config({}, num_stages=1, num_warps=num_warps)],
+        configs,
         triton_meta=triton_meta,
        inductor_meta=inductor_meta,
         heuristic_type=HeuristicType.TEMPLATE,
         filename=filename,
     )
 
-
 @dataclasses.dataclass
 class GridExpr:
     """Generate code for grid size expressions in launcher"""
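
For reference, below is a minimal standalone sketch of the XBLOCK candidate filter introduced in the _persistent_reduction_configs hunk above. The helper name candidate_xblocks is hypothetical, and plain integers stand in for the triton_config_reduction(...) objects the real code builds; the point is only to illustrate how enabling max-autotune keeps the larger XBLOCK candidates that the rnumel * xblock <= 4096 cap would otherwise drop.

# Hypothetical illustration only; the patched code appends triton_config_reduction(...)
# configs rather than returning raw XBLOCK values.
def candidate_xblocks(xnumel, rnumel, max_autotune_enabled):
    return [
        xblock
        for xblock in (1, 8, 32, 128)
        if xblock == 1
        or (xblock <= xnumel and (max_autotune_enabled or rnumel * xblock <= 4096))
    ]

# With autotuning disabled, candidates above the 4096-element persistent cap are dropped.
print(candidate_xblocks(xnumel=128, rnumel=512, max_autotune_enabled=False))  # [1, 8]
# With max-autotune enabled, every candidate that fits xnumel is kept for tuning.
print(candidate_xblocks(xnumel=128, rnumel=512, max_autotune_enabled=True))   # [1, 8, 32, 128]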