diff --git a/.ci/docker/ci_commit_pins/triton.txt b/.ci/docker/ci_commit_pins/triton.txt index 6c42823f59a57..a167ee868969a 100644 --- a/.ci/docker/ci_commit_pins/triton.txt +++ b/.ci/docker/ci_commit_pins/triton.txt @@ -1 +1 @@ -21876a4bbaf371bcb83df8e6ee4f43a92f524dfe +0cace8d2336a9dc399effbb11522eea7f7b8c0b2 diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 7d65354e7a2f4..7f97ea96f2322 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -26,7 +26,7 @@ from torch._prims_common import is_integer_dtype from torch.utils._ordered_set import OrderedSet from torch.utils._sympy.functions import CeilDiv, FloorDiv, ModularIndexing -from torch.utils._triton import has_triton_package +from torch.utils._triton import has_triton_package, get_triton_version from ...utils._sympy.symbol import free_symbol_is_type, prefix_str, symbol_is_type, SymT from ...utils._sympy.value_ranges import ValueRanges @@ -1232,7 +1232,12 @@ def tan(x): @staticmethod @maybe_upcast_float32() def tanh(x): - return f"libdevice.fast_tanhf({x})" + if torch.version.hip and get_triton_version() > (3, 2): + # On ROCm, use fast_tanhf depending on Triton version + # Requires ROCm fork of Triton 3.3, 3.4, 3.5 or upstream Triton 3.6+ + return f"libdevice.fast_tanhf({x})" + else: + return f"libdevice.tanh({x})" @staticmethod @maybe_upcast_float32()