From 78f604ae3b97ad1733d89b4ead1b723826fb9191 Mon Sep 17 00:00:00 2001 From: "Nichols A. Romero" Date: Wed, 12 Nov 2025 22:07:15 +0000 Subject: [PATCH 1/3] On ROCm, always use fast_tanhf for triton codegen. (cherry picked from commit 7c5277f22a9917902d962f61792e331dfd93cd64) --- torch/_inductor/codegen/triton.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 7d65354e7a2f4..ca68e90f6b1f4 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -1232,7 +1232,12 @@ def tan(x): @staticmethod @maybe_upcast_float32() def tanh(x): - return f"libdevice.fast_tanhf({x})" + # On ROCm, always use fast_tanhf + # Requires ROCm fork of Triton 3.3, 3.4, 3.5 or upstream Triton 3.6+ + if torch.version.hip: + return f"libdevice.fast_tanhf({x})" + else: + return f"libdevice.tanh({x})" @staticmethod @maybe_upcast_float32() From 084d7b39ee03b12ab04873ab83bd5d270e241f5a Mon Sep 17 00:00:00 2001 From: "Nichols A. Romero" Date: Thu, 13 Nov 2025 01:40:38 +0000 Subject: [PATCH 2/3] Bump up Triton commit to support fast_tanhf. --- .ci/docker/ci_commit_pins/triton.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.ci/docker/ci_commit_pins/triton.txt b/.ci/docker/ci_commit_pins/triton.txt index 6c42823f59a57..a167ee868969a 100644 --- a/.ci/docker/ci_commit_pins/triton.txt +++ b/.ci/docker/ci_commit_pins/triton.txt @@ -1 +1 @@ -21876a4bbaf371bcb83df8e6ee4f43a92f524dfe +0cace8d2336a9dc399effbb11522eea7f7b8c0b2 From 7cc238e2838296552a9075e186cdbafb4d519346 Mon Sep 17 00:00:00 2001 From: "Nichols A. Romero" Date: Fri, 14 Nov 2025 20:27:49 +0000 Subject: [PATCH 3/3] Conditionalize fast_tanhf on triton_version. (cherry picked from commit f416c7119ad1443bf022a37a8f3f21b201aa4bbc) --- torch/_inductor/codegen/triton.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index ca68e90f6b1f4..7f97ea96f2322 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -26,7 +26,7 @@ from torch._prims_common import is_integer_dtype from torch.utils._ordered_set import OrderedSet from torch.utils._sympy.functions import CeilDiv, FloorDiv, ModularIndexing -from torch.utils._triton import has_triton_package +from torch.utils._triton import has_triton_package, get_triton_version from ...utils._sympy.symbol import free_symbol_is_type, prefix_str, symbol_is_type, SymT from ...utils._sympy.value_ranges import ValueRanges @@ -1232,9 +1232,9 @@ def tan(x): @staticmethod @maybe_upcast_float32() def tanh(x): - # On ROCm, always use fast_tanhf - # Requires ROCm fork of Triton 3.3, 3.4, 3.5 or upstream Triton 3.6+ - if torch.version.hip: + if torch.version.hip and get_triton_version() > (3, 2): + # On ROCm, use fast_tanhf depending on Triton version + # Requires ROCm fork of Triton 3.3, 3.4, 3.5 or upstream Triton 3.6+ return f"libdevice.fast_tanhf({x})" else: return f"libdevice.tanh({x})"