Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 21 additions & 12 deletions manywheel/build_rocm.sh
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ ROCM_VERSION_WITH_PATCH=rocm${ROCM_VERSION_MAJOR}.${ROCM_VERSION_MINOR}.${ROCM_V
ROCM_INT=$(($ROCM_VERSION_MAJOR * 10000 + $ROCM_VERSION_MINOR * 100 + $ROCM_VERSION_PATCH))

PYTORCH_VERSION=$(cat $PYTORCH_ROOT/version.txt | grep -oP "[0-9]+\.[0-9]+\.[0-9]+")

PYTORCH_VERSION_FULL=$(cat "$PYTORCH_ROOT/version.txt")
do_lightweight_build() {
echo "=== Building LIGHTWEIGHT variant ==="

Expand Down Expand Up @@ -348,25 +348,34 @@ ver() {
# Assuming PYTORCH_VERSION=x.y.z, if x >= 2
if [ ${PYTORCH_VERSION%%\.*} -ge 2 ]; then
if [[ $(uname) == "Linux" ]] && [[ "$DESIRED_PYTHON" != "3.12" || $(ver $PYTORCH_VERSION) -ge $(ver 2.4) ]]; then
# Triton commit got unified in PyTorch 2.5
if [[ $(ver $PYTORCH_VERSION) -ge $(ver 2.5) ]]; then
# Triton commit got unified in PyTorch 2.5
if [[ $(ver $PYTORCH_VERSION) -ge $(ver 2.5) ]]; then
TRITON_SHORTHASH=$(cut -c1-8 $PYTORCH_ROOT/.ci/docker/ci_commit_pins/triton.txt)
else
else
TRITON_SHORTHASH=$(cut -c1-8 $PYTORCH_ROOT/.ci/docker/ci_commit_pins/triton-rocm.txt)
fi
fi
TRITON_VERSION=$(cat $PYTORCH_ROOT/.ci/docker/triton_version.txt)
# Only linux Python < 3.13 are supported wheels for triton
TRITON_CONSTRAINT="platform_system == 'Linux' and platform_machine == 'x86_64'$(if [[ $(ver "$PYTORCH_VERSION") -le $(ver "2.5") ]]; then echo " and python_version < '3.13'"; fi)"

if [[ -z "$PYTORCH_EXTRA_INSTALL_REQUIREMENTS" ]]; then
export PYTORCH_EXTRA_INSTALL_REQUIREMENTS="pytorch-triton-rocm==${TRITON_VERSION}+${ROCM_VERSION_WITH_PATCH}.git${TRITON_SHORTHASH}; ${TRITON_CONSTRAINT}"
# Only linux Python < 3.13 are supported wheels for triton
TRITON_CONSTRAINT="platform_system == 'Linux' and platform_machine == 'x86_64'$(if [[ $(ver "$PYTORCH_VERSION") -le $(ver "2.5") ]]; then echo " and python_version < '3.13'"; fi)"
# Use "triton" for dev builds, else "pytorch-triton-rocm"
# Temp: Currently enabling for rocm7.1_internal_testing branch only but plan to expand it to other branches
if [[ "$PYTORCH_VERSION_FULL" == *"2.9.0a0"* ]]; then
PKG="triton"
else
export PYTORCH_EXTRA_INSTALL_REQUIREMENTS="${PYTORCH_EXTRA_INSTALL_REQUIREMENTS} | pytorch-triton-rocm==${TRITON_VERSION}+${ROCM_VERSION_WITH_PATCH}.git${TRITON_SHORTHASH}; ${TRITON_CONSTRAINT}"
PKG="pytorch-triton-rocm"
fi

REQ="${PKG}==${TRITON_VERSION}+${ROCM_VERSION_WITH_PATCH}.git${TRITON_SHORTHASH}; ${TRITON_CONSTRAINT}"

if [[ -z "$PYTORCH_EXTRA_INSTALL_REQUIREMENTS" ]]; then
export PYTORCH_EXTRA_INSTALL_REQUIREMENTS="${REQ}"
else
export PYTORCH_EXTRA_INSTALL_REQUIREMENTS="${PYTORCH_EXTRA_INSTALL_REQUIREMENTS} | ${REQ}"
fi
unset PKG REQ
fi
fi


echo "PYTORCH_ROCM_ARCH: ${PYTORCH_ROCM_ARCH}"

export LIGHTWEIGHT_WHEELNAME_MARKER="${LIGHTWEIGHT_WHEELNAME_MARKER}"
Expand Down
Loading