diff --git a/.github/scripts/build_triton_wheel.py b/.github/scripts/build_triton_wheel.py index 695b4a9c865a..ec26d8cbefaf 100644 --- a/.github/scripts/build_triton_wheel.py +++ b/.github/scripts/build_triton_wheel.py @@ -101,8 +101,12 @@ def build_triton( triton_repo = "https://github.com/openai/triton" if device == "rocm": - triton_pkg_name = "pytorch-triton-rocm" triton_repo = "https://github.com/ROCm/triton" + rocm_version = get_rocm_version() # e.g., "7.0.1" + if tuple(map(int, rocm_version.split("."))) > (7, 0, 0): + triton_pkg_name = "triton" + else: + triton_pkg_name = "pytorch-triton-rocm" elif device == "xpu": triton_pkg_name = "pytorch-triton-xpu" triton_repo = "https://github.com/intel/intel-xpu-backend-for-triton"