2 changes: 1 addition & 1 deletion .ci/docker/ci_commit_pins/triton.txt
@@ -1 +1 @@
711e2a92522e0a9921ce58ae658571ca55c49b97
b2c0ea435ece3491b2940af7c08d42974b953e06
17 changes: 0 additions & 17 deletions test/inductor/test_combo_kernels.py
@@ -296,23 +296,6 @@ def fn(a0, a1, a2, b0, b1, b2):

self.assertTrue(7 <= torch._inductor.metrics.generated_kernel_count <= 8)

@requires_cuda
def test_persistent_reduction_no_x_dim(self):
def fn(x, y):
return x.sum(1), y.sum(1)

inps = (
torch.rand(16, 256, device="cuda"),
torch.rand(32, 256, device="cuda"),
)
torch._dynamo.mark_dynamic(inps[0], 0, min=1, max=256)
torch._dynamo.mark_dynamic(inps[1], 0, min=1, max=256)
out_eager = fn(*inps)
out_compiled = torch.compile(fn)(*inps)

self.assertEqual(out_eager, out_compiled)
self.assertEqual(torch._inductor.metrics.generated_kernel_count, 4)


@instantiate_parametrized_tests
class ComboKernelDynamicShapesTests(TestCase):
25 changes: 0 additions & 25 deletions test/inductor/test_torchinductor_strided_blocks.py
@@ -746,31 +746,6 @@ def test_2d_reduction_odd_shapes(
# Check the code for multiple Rn_BLOCK's
self._assert_reduction_ndims(code, 2)

def test_2d_reduction_no_x_dim(self):
"""
Tests a 2D reduction without an "x" dimension.
"""
# We need a size to get no x dim.
view = self._discontiguous_tensor((2, 346), self.device)

# Expect 1 block pointer for the input.
result, (code,) = self._run_and_compare(
torch.prod,
view,
expected_num_block_pointers=1,
expected_num_triton_kernels=1,
config_patches=tiled_reduction_config,
)

# Check that there's no X dimension in the signature.
(signature_line,) = (
line for line in code.splitlines() if line.startswith("def triton")
)
self.assertNotIn("BLOCK", signature_line)

# Check for 2 reduction dimensions in the body.
self._assert_reduction_ndims(code, 2)

@parametrize(
"size,expected_num_block_pointers,expected_num_triton_kernels,expect_fallback",
[
6 changes: 3 additions & 3 deletions torch/_higher_order_ops/triton_kernel_wrap.py
@@ -457,15 +457,15 @@ def get_signature_value(idx: int, arg: Any) -> str:
inspect.signature(backend.get_codegen_implementation).parameters
)
if make_ir_sig_params == 2:
ttir_module = src.make_ir(options, context)
ttir_module = src.make_ir(target, options, context)
elif make_ir_sig_params == 3:
codegen_fns = backend.get_codegen_implementation()
ttir_module = src.make_ir(options, codegen_fns, context)
ttir_module = src.make_ir(target, options, codegen_fns, context)
else:
codegen_args = [options] if get_codegen_implementation_sig_params == 1 else []
codegen_fns = backend.get_codegen_implementation(*codegen_args)
module_map = backend.get_module_map()
ttir_module = src.make_ir(options, codegen_fns, module_map, context)
ttir_module = src.make_ir(target, options, codegen_fns, module_map, context)
if not ttir_module.verify():
raise RuntimeError("Verification for TTIR module has failed")
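The block above dispatches on parameter counts obtained from `inspect.signature`, passing the extra `target` argument only to the `make_ir` variants that accept it. A minimal standalone sketch of that arity-based dispatch pattern (hypothetical `make_ir_*` stand-ins, not the actual Triton API):

import inspect

def make_ir_legacy(options, context):  # hypothetical older-style signature
    return ("legacy", options, context)

def make_ir_with_target(target, options, context):  # hypothetical newer signature
    return ("target", target, options, context)

def call_make_ir(fn, target, options, context):
    # Count declared parameters, mirroring len(inspect.signature(...).parameters)
    # in the code above, and pass `target` only when the signature expects it.
    if len(inspect.signature(fn).parameters) == 2:
        return fn(options, context)
    return fn(target, options, context)

print(call_make_ir(make_ir_legacy, "gfx942", {}, None))
print(call_make_ir(make_ir_with_target, "gfx942", {}, None))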

8 changes: 4 additions & 4 deletions torch/_inductor/choices.py
@@ -202,11 +202,11 @@ def want_no_x_dim(features: SIMDKernelFeatures) -> bool:
Heuristic to decide if we should drop the X dimension from a persistent reduction kernel.
So the [XBLOCK, RBLOCK] block becomes a [RBLOCK] block and XBLOCK is forced to be always 1.
Strangely this is faster than a [1, RBLOCK] block in some cases.

ROCm branch change: Remove want_no_x_dim for persistent reduction.
Inductor benchmarks show no perf advantage, and removing it simplifies the autotune flow.
"""
return (
features.get_reduction_hint() == ReductionHint.INNER
and V.graph.sizevars.statically_known_geq(features.reduction_numel, 256)
)
return False

@staticmethod
def reduction_split_factor(
13 changes: 5 additions & 8 deletions torch/_inductor/codegen/triton.py
@@ -1288,7 +1288,7 @@ def tan(x):
@staticmethod
@maybe_upcast_float32()
def tanh(x):
return f"libdevice.tanh({x})"
return f"libdevice.fast_tanhf({x})"

@staticmethod
@maybe_upcast_float32()
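Assuming `libdevice.fast_tanhf` is a faster, lower-precision variant of `tanh` in the ROCm libdevice bindings, a quick accuracy sanity check against eager mode might look like this (a sketch; tolerances are illustrative):

import torch

def check_compiled_tanh(device="cuda", dtype=torch.float32):
    # Compare eager tanh with the Inductor-compiled version within a loose tolerance.
    x = torch.randn(4096, device=device, dtype=dtype)
    eager = torch.tanh(x)
    compiled = torch.compile(torch.tanh)(x)
    torch.testing.assert_close(compiled, eager, atol=1e-3, rtol=1e-3)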
@@ -1999,13 +1999,10 @@ def should_use_persistent_reduction(self) -> bool:
)

def want_no_x_dim(self):
if (
self.persistent_reduction
and len(self.numels) == self.num_reduction_dims + 1
):
if self.fixed_config:
return self.fixed_config["XBLOCK"] == 1
return V.choices.want_no_x_dim(self.features)
"""
ROCm branch change: Remove want_no_x_dim for persistent reduction.
Inductor benchmarks show no perf advantage, and removing it simplifies the autotune flow.
"""
return False

@property
2 changes: 1 addition & 1 deletion torch/_inductor/codegen/triton_combo_kernel.py
@@ -614,7 +614,7 @@ def jit_line(
if heuristics == "foreach":
heuristics_line = f"""
@triton_heuristics.foreach(
num_warps={self.num_warps},
filename=__file__,
triton_meta={triton_meta!r},
inductor_meta={inductor_meta!r},
)
13 changes: 10 additions & 3 deletions torch/_inductor/runtime/coordinate_descent_tuner.py
@@ -3,6 +3,7 @@
import itertools
import logging
from typing import Callable, Optional, TYPE_CHECKING
from functools import lru_cache

from .hints import TRITON_MAX_BLOCK
from .runtime_utils import red_text, triton_config_to_hashable
@@ -60,10 +61,16 @@ def get_config_max(self, prefix: str) -> int:
size_hint = self.size_hints.get(prefix) if self.size_hints is not None else None
return min(max_block, size_hint) if size_hint is not None else max_block

@lru_cache(maxsize=1)
def get_warpsmax(self):
# Currently, CUDA has a maximum of 1024 threads, so 32 is the max
# number of warps.
return 1024 // 32
# Both CUDA and ROCm have a maximum of 1024 threads per block, so the
# warp cap is 1024 divided by the device's warp size.
from torch.cuda import current_device, get_device_properties, is_available

warp_size = (
get_device_properties(current_device()).warp_size if is_available() else 32
)

return 1024 // warp_size

def cache_benchmark_result(self, config, timing):
self.cached_benchmark_results[triton_config_to_hashable(config)] = timing
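With this change the warp cap follows the device's warp size rather than assuming 32: 1024 // 32 = 32 warps on warp-size-32 devices, 1024 // 64 = 16 on devices with 64-wide warps (common on AMD GPUs). A rough standalone equivalent (a sketch, not the tuner code itself):

import torch

def max_warps_per_block(threads_per_block=1024):
    # Mirror get_warpsmax(): query the device's warp size, falling back to 32
    # when no GPU is available.
    if torch.cuda.is_available():
        props = torch.cuda.get_device_properties(torch.cuda.current_device())
        warp_size = props.warp_size
    else:
        warp_size = 32
    return threads_per_block // warp_size

print(max_warps_per_block())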
53 changes: 40 additions & 13 deletions torch/_inductor/runtime/triton_heuristics.py
@@ -2689,21 +2689,30 @@ def _persistent_reduction_configs(
xnumel = size_hints["x"]
rnumel = get_total_reduction_numel(size_hints)

MAX_PERSISTENT_BLOCK_NUMEL = 4096
max_autotune_enabled = not disable_pointwise_autotuning(inductor_meta) or (
inductor_meta.get("max_autotune")
or inductor_meta.get("max_autotune_pointwise")
)

configs = [
triton_config_reduction(size_hints, xblock, rnumel, register_intensive=True)
for xblock in (1, 8, 32, 128)
if xblock == 1 or (xblock <= xnumel and (max_autotune_enabled or rnumel * xblock <= 4096))
]

if "y" not in size_hints:
configs = [
triton_config_reduction(size_hints, xblock, rnumel, register_intensive=True)
for xblock in (1, 8, 32, 128)
if xblock == 1
or (rnumel * xblock <= MAX_PERSISTENT_BLOCK_NUMEL and xblock <= xnumel)
or (rnumel * xblock <= 4096 and xblock <= xnumel)
]
else:
configs = []
assert "tiling_scores" in inductor_meta
x_y_scores = {dim: inductor_meta["tiling_scores"][dim] for dim in ("x", "y")}
for target_block_size in (1, 8, 32, 64, 128):
if target_block_size * rnumel > MAX_PERSISTENT_BLOCK_NUMEL:
if target_block_size * rnumel > 4096:
continue

block_sizes = match_target_block_product(
@@ -2718,19 +2727,28 @@
# defer to more autotuning, initially
if "y" in size_hints:
pass
# TODO(jansel): we should be able to improve these heuristics
elif reduction_hint == ReductionHint.INNER and rnumel >= 256:
configs = configs[:1]
elif reduction_hint == ReductionHint.OUTER:
configs = configs[-1:]
elif reduction_hint == ReductionHint.OUTER_TINY:
configs = [

if not max_autotune_enabled: # Don't filter if tuning enabled
if reduction_hint == ReductionHint.INNER and rnumel >= 256:
configs = configs[:1]
elif reduction_hint == ReductionHint.OUTER:
configs = configs[-1:]

if reduction_hint == ReductionHint.OUTER_TINY:
tiny_configs = [
triton_config_reduction(
size_hints,
2 * (256 // rnumel) if rnumel <= 256 else 1,
rnumel,
)
]
if max_autotune_enabled:
for tconfig in tiny_configs:
if tconfig not in configs:
configs.append(tconfig)
else:
configs = tiny_configs

for c in configs:
# we don't need Rn_BLOCK for persistent reduction
for prefix in size_hints:
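As a worked illustration of the new xblock filter (a standalone sketch of the predicate used above, not the actual helper), with max-autotune off the 4096-element cap prunes the larger block sizes:

def candidate_xblocks(xnumel, rnumel, max_autotune_enabled):
    # XBLOCK=1 is always kept; larger candidates must fit xnumel and, unless
    # max-autotune is enabled, keep the persistent block within 4096 elements.
    return [
        xblock
        for xblock in (1, 8, 32, 128)
        if xblock == 1
        or (xblock <= xnumel and (max_autotune_enabled or rnumel * xblock <= 4096))
    ]

print(candidate_xblocks(xnumel=1024, rnumel=512, max_autotune_enabled=False))  # [1, 8]
print(candidate_xblocks(xnumel=1024, rnumel=512, max_autotune_enabled=True))   # [1, 8, 32, 128]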
@@ -2922,20 +2940,29 @@ def user_autotune(
)


def foreach(triton_meta, num_warps, filename=None, inductor_meta=None):
def foreach(triton_meta, filename=None, inductor_meta=None):
"""
Compile a triton foreach kernel
"""
configs = []
if disable_pointwise_autotuning(inductor_meta) and not (
inductor_meta.get("max_autotune") or
inductor_meta.get("max_autotune_pointwise")
):
configs.append(triton.Config({}, num_stages=1, num_warps=8))
else:
for warps in [1, 2, 4, 8]:
configs.append(triton.Config({}, num_stages=1, num_warps=warps))

return cached_autotune(
None,
[triton.Config({}, num_stages=1, num_warps=num_warps)],
configs,
triton_meta=triton_meta,
inductor_meta=inductor_meta,
heuristic_type=HeuristicType.TEMPLATE,
filename=filename,
)
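To make the gating explicit, here is a standalone sketch of the config selection in the rewritten `foreach()` (warp counts stand in for `triton.Config` objects, and `autotune_disabled` stands in for `disable_pointwise_autotuning(inductor_meta)`):

def foreach_warp_candidates(inductor_meta, autotune_disabled):
    # Single fixed config when pointwise autotuning is off and no
    # max-autotune flag is set; otherwise one candidate per warp count.
    inductor_meta = inductor_meta or {}
    if autotune_disabled and not (
        inductor_meta.get("max_autotune")
        or inductor_meta.get("max_autotune_pointwise")
    ):
        return [8]
    return [1, 2, 4, 8]

print(foreach_warp_candidates({}, autotune_disabled=True))                      # [8]
print(foreach_warp_candidates({"max_autotune": True}, autotune_disabled=True))  # [1, 2, 4, 8]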


@dataclasses.dataclass
class GridExpr:
"""Generate code for grid size expressions in launcher"""