diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py
index 9039fa2b51a1f..e9c5b910ba02f 100644
--- a/torch/_inductor/codegen/triton.py
+++ b/torch/_inductor/codegen/triton.py
@@ -25,7 +25,7 @@
 from torch._prims_common import is_integer_dtype
 from torch.utils._ordered_set import OrderedSet
 from torch.utils._sympy.functions import CeilDiv, FloorDiv, ModularIndexing
-from torch.utils._triton import has_triton_package
+from torch.utils._triton import has_triton_package, get_triton_version
 
 from ...utils._sympy.symbol import free_symbol_is_type, prefix_str, symbol_is_type, SymT
 from ...utils._sympy.value_ranges import ValueRanges
@@ -1217,7 +1217,13 @@ def tan(x):
     @staticmethod
     @maybe_upcast_float32()
     def tanh(x):
-        return f"libdevice.fast_tanhf({x})"
+        if config.use_fast_math and torch.version.hip:
+            if get_triton_version() > (3, 4):
+                return f"libdevice.fast_tanhf({x})"
+            else:
+                return f"libdevice.tanh({x})"
+        else:
+            return f"libdevice.tanh({x})"
 
     @staticmethod
     @maybe_upcast_float32()
diff --git a/torch/utils/_triton.py b/torch/utils/_triton.py
index 1609a3fe77c87..4b6d135f18467 100644
--- a/torch/utils/_triton.py
+++ b/torch/utils/_triton.py
@@ -61,6 +61,17 @@ def has_triton_tma_device():
     return False
 
 
+@functools.cache
+def get_triton_version(fallback: tuple[int, int] = (0, 0)) -> tuple[int, int]:
+    try:
+        import triton  # noqa: F401
+
+        major, minor = tuple(int(v) for v in triton.__version__.split(".")[:2])
+        return (major, minor)
+    except ImportError:
+        return fallback
+
+
 @functools.lru_cache(None)
 def has_triton() -> bool:
     if not has_triton_package():
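
Note on the gate in `tanh`: since both `else` arms emit `libdevice.tanh`, the nested branches are equivalent to one condition. Below is a minimal, self-contained sketch of that logic; `parse_triton_version` and `tanh_impl` are hypothetical stand-ins (not part of this diff) that take the version string and flags directly so the behavior can be exercised without Triton or a ROCm build:

```python
# Sketch only: parse_triton_version and tanh_impl are hypothetical stand-ins
# for get_triton_version and the tanh codegen hook in the diff above,
# written to be runnable without Triton installed.

def parse_triton_version(
    version: str, fallback: tuple[int, int] = (0, 0)
) -> tuple[int, int]:
    try:
        # Same parse as get_triton_version: first two dotted components as ints.
        major, minor = (int(v) for v in version.split(".")[:2])
        return (major, minor)
    except ValueError:
        # Covers non-numeric components or a version with fewer than two parts.
        return fallback


def tanh_impl(
    use_fast_math: bool, is_hip: bool, triton_version: tuple[int, int]
) -> str:
    # Equivalent to the nested branches above: fast_tanhf is emitted only on
    # a HIP build with fast-math enabled and Triton strictly newer than 3.4.
    if use_fast_math and is_hip and triton_version > (3, 4):
        return "libdevice.fast_tanhf(x)"
    return "libdevice.tanh(x)"


assert parse_triton_version("3.5.0") == (3, 5)
assert tanh_impl(True, True, (3, 5)) == "libdevice.fast_tanhf(x)"
assert tanh_impl(True, True, (3, 4)) == "libdevice.tanh(x)"   # 3.4 itself falls back
assert tanh_impl(False, True, (3, 5)) == "libdevice.tanh(x)"  # fast-math disabled
assert tanh_impl(True, False, (3, 5)) == "libdevice.tanh(x)"  # non-HIP build
```

One difference worth flagging: the sketch catches `ValueError` because it parses an arbitrary string, while `get_triton_version` in the diff only catches `ImportError`, relying on `triton.__version__` being well formed when the package imports.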