From 7c5277f22a9917902d962f61792e331dfd93cd64 Mon Sep 17 00:00:00 2001 From: "Nichols A. Romero" Date: Wed, 12 Nov 2025 22:07:15 +0000 Subject: [PATCH 1/3] On ROCm, always use fast_tanhf for triton codegen. --- torch/_inductor/codegen/triton.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index e9c5b910ba02f..55e4079b014c6 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -25,7 +25,7 @@ from torch._prims_common import is_integer_dtype from torch.utils._ordered_set import OrderedSet from torch.utils._sympy.functions import CeilDiv, FloorDiv, ModularIndexing -from torch.utils._triton import has_triton_package, get_triton_version +from torch.utils._triton import has_triton_package from ...utils._sympy.symbol import free_symbol_is_type, prefix_str, symbol_is_type, SymT from ...utils._sympy.value_ranges import ValueRanges @@ -1217,11 +1217,10 @@ def tan(x): @staticmethod @maybe_upcast_float32() def tanh(x): - if config.use_fast_math and torch.version.hip: - if get_triton_version() > (3, 4): - return f"libdevice.fast_tanhf({x})" - else: - return f"libdevice.tanh({x})" + # On ROCm, always use fast_tanhf + # Requires ROCm fork of Triton 3.3, 3.4, 3.5 or upstream Triton 3.6+ + if torch.version.hip: + return f"libdevice.fast_tanhf({x})" else: return f"libdevice.tanh({x})" From 1b1fde5fcc342c2c0d3c69bf95a91501fc39b324 Mon Sep 17 00:00:00 2001 From: "Nichols A. Romero" Date: Wed, 12 Nov 2025 22:20:04 +0000 Subject: [PATCH 2/3] Pump up Triton commit to support fast_tanhf. --- .ci/docker/ci_commit_pins/triton.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.ci/docker/ci_commit_pins/triton.txt b/.ci/docker/ci_commit_pins/triton.txt index 24d633a34eadf..d37a2be2b42af 100644 --- a/.ci/docker/ci_commit_pins/triton.txt +++ b/.ci/docker/ci_commit_pins/triton.txt @@ -1 +1 @@ -9c7bc0a3d41407bff948b40cd0e9c793147e49bc +80ed7f41e4b5d6e71651847e4725f4e7c2999a08 From f416c7119ad1443bf022a37a8f3f21b201aa4bbc Mon Sep 17 00:00:00 2001 From: "Nichols A. Romero" Date: Fri, 14 Nov 2025 20:27:49 +0000 Subject: [PATCH 3/3] Conditionalize fast_tanhf on triton_version. --- torch/_inductor/codegen/triton.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 55e4079b014c6..43f7c8285cb41 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -25,7 +25,7 @@ from torch._prims_common import is_integer_dtype from torch.utils._ordered_set import OrderedSet from torch.utils._sympy.functions import CeilDiv, FloorDiv, ModularIndexing -from torch.utils._triton import has_triton_package +from torch.utils._triton import has_triton_package, get_triton_version from ...utils._sympy.symbol import free_symbol_is_type, prefix_str, symbol_is_type, SymT from ...utils._sympy.value_ranges import ValueRanges @@ -1217,9 +1217,9 @@ def tan(x): @staticmethod @maybe_upcast_float32() def tanh(x): - # On ROCm, always use fast_tanhf - # Requires ROCm fork of Triton 3.3, 3.4, 3.5 or upstream Triton 3.6+ - if torch.version.hip: + if torch.version.hip and get_triton_version() > (3, 2): + # On ROCm, use fast_tanhf depending on Triton version + # Requires ROCm fork of Triton 3.3, 3.4, 3.5 or upstream Triton 3.6+ return f"libdevice.fast_tanhf({x})" else: return f"libdevice.tanh({x})"