Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .ci/docker/ci_commit_pins/triton.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
21876a4bbaf371bcb83df8e6ee4f43a92f524dfe
0cace8d2336a9dc399effbb11522eea7f7b8c0b2
9 changes: 7 additions & 2 deletions torch/_inductor/codegen/triton.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from torch._prims_common import is_integer_dtype
from torch.utils._ordered_set import OrderedSet
from torch.utils._sympy.functions import CeilDiv, FloorDiv, ModularIndexing
from torch.utils._triton import has_triton_package
from torch.utils._triton import has_triton_package, get_triton_version

from ...utils._sympy.symbol import free_symbol_is_type, prefix_str, symbol_is_type, SymT
from ...utils._sympy.value_ranges import ValueRanges
Expand Down Expand Up @@ -1232,7 +1232,12 @@ def tan(x):
@staticmethod
@maybe_upcast_float32()
def tanh(x):
return f"libdevice.fast_tanhf({x})"
if torch.version.hip and get_triton_version() > (3, 2):
# On ROCm, use fast_tanhf depending on Triton version
# Requires ROCm fork of Triton 3.3, 3.4, 3.5 or upstream Triton 3.6+
return f"libdevice.fast_tanhf({x})"
else:
return f"libdevice.tanh({x})"

@staticmethod
@maybe_upcast_float32()
Expand Down