Skip to content

Commit

Permalink
Pin triton-rocm to latest 2.3.1 commit (pytorch#126309)
Browse files Browse the repository at this point in the history
* Revert "pin rocm"

This reverts commit 45ebb10.

* Revert "lint"

This reverts commit 05860b9.

* rocm_pin
  • Loading branch information
atalman authored May 15, 2024
1 parent 03baf94 commit 0365423
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 12 deletions.
2 changes: 1 addition & 1 deletion .ci/docker/ci_commit_pins/triton-rocm.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
d08e16b738ab550c3af51305df624d5c823dc445
c8ad905211f45e162102823149f0d7f2cfaa4418
16 changes: 5 additions & 11 deletions .github/scripts/build_triton_wheel.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,6 @@
SCRIPT_DIR = Path(__file__).parent
REPO_DIR = SCRIPT_DIR.parent.parent

# TODO: Remove me once Triton version is again in sync for vanilla and ROCm
ROCM_TRITION_VERSION = "2.3.0"


def read_triton_pin(rocm_hash: bool = False) -> str:
triton_file = "triton.txt" if not rocm_hash else "triton-rocm.txt"
Expand Down Expand Up @@ -101,12 +98,9 @@ def build_triton(
check_call(["git", "clone", triton_repo], cwd=tmpdir)
if release:
ver, rev, patch = version.split(".")
if build_rocm:
check_call(["git", "checkout", "release/2.3.x"], cwd=triton_basedir)
else:
check_call(
["git", "checkout", f"release/{ver}.{rev}.x"], cwd=triton_basedir
)
check_call(
["git", "checkout", f"release/{ver}.{rev}.x"], cwd=triton_basedir
)
else:
check_call(["git", "checkout", commit_hash], cwd=triton_basedir)

Expand Down Expand Up @@ -165,7 +159,7 @@ def build_triton(
patch_init_py(
triton_pythondir / "triton" / "__init__.py",
version=f"{version}",
expected_version=ROCM_TRITION_VERSION if build_rocm else None,
expected_version=None,
)

if build_rocm:
Expand All @@ -174,7 +168,7 @@ def build_triton(
triton_pythondir / "setup.py",
name=triton_pkg_name,
version=f"{version}",
expected_version=ROCM_TRITION_VERSION,
expected_version=None,
)
check_call("scripts/amd/setup_rocm_libs.sh", cwd=triton_basedir, shell=True)
print("ROCm libraries setup for triton installation...")
Expand Down

0 comments on commit 0365423

Please sign in to comment.