diff --git a/.ci/docker/ci_commit_pins/triton.txt b/.ci/docker/ci_commit_pins/triton.txt index 24d633a34eadf..d37a2be2b42af 100644 --- a/.ci/docker/ci_commit_pins/triton.txt +++ b/.ci/docker/ci_commit_pins/triton.txt @@ -1 +1 @@ -9c7bc0a3d41407bff948b40cd0e9c793147e49bc +80ed7f41e4b5d6e71651847e4725f4e7c2999a08 diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index e9c5b910ba02f..43f7c8285cb41 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -1217,11 +1217,10 @@ def tan(x): @staticmethod @maybe_upcast_float32() def tanh(x): - if config.use_fast_math and torch.version.hip: - if get_triton_version() > (3, 4): - return f"libdevice.fast_tanhf({x})" - else: - return f"libdevice.tanh({x})" + if torch.version.hip and get_triton_version() > (3, 2): + # On ROCm, use fast_tanhf depending on Triton version + # Requires ROCm fork of Triton 3.3, 3.4, 3.5 or upstream Triton 3.6+ + return f"libdevice.fast_tanhf({x})" else: return f"libdevice.tanh({x})"