From 33b93456a9c37ca5c914b16236ce009b80c2fd3d Mon Sep 17 00:00:00 2001 From: Vicky Tsang Date: Mon, 25 Nov 2024 17:33:55 +0000 Subject: [PATCH 01/27] [ROCm] Select gpu targets according to PYTORCH_ROCM_ARCH when building AOTriton from source (#139432) Pull Request resolved: https://github.com/pytorch/pytorch/pull/139432 Approved by: https://github.com/jithunnair-amd, https://github.com/jeffdaily Co-authored-by: Vicky Tsang --- cmake/External/aotriton.cmake | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/cmake/External/aotriton.cmake b/cmake/External/aotriton.cmake index bc8535a88ef80..fabbe8a48b624 100644 --- a/cmake/External/aotriton.cmake +++ b/cmake/External/aotriton.cmake @@ -1,3 +1,15 @@ +macro(get_target_gpus_from_pytorch target_gpus) + set(gfx90a_key MI200) + set(gfx942_key MI300X) + + foreach(X IN LISTS PYTORCH_ROCM_ARCH) + set(key ${X}) + string(APPEND key "_key") + string(APPEND target_gpus ${${key}}) + string(APPEND target_gpus "|") + endforeach() +endmacro() + if(NOT __AOTRITON_INCLUDED) set(__AOTRITON_INCLUDED TRUE) From 885397840b511f12c273d721a594045d4d991a54 Mon Sep 17 00:00:00 2001 From: Xinya Zhang Date: Thu, 9 Jan 2025 00:00:02 +0000 Subject: [PATCH 02/27] Let aotriton.cmake detect the best binary package to use, and deprecate aotriton_version.txt (#137443) We do not need `install_aotriton.sh` and `aotriton_version.txt` any more since `aotriton.cmake` now installs the best binary release package as the default option when building pytorch. This should resolve the issue of needing a pre-installed aotriton package when building PyTorch for ROCm from source, which is not feasible if building PyTorch *outside* a CI docker image. With this change, a user can have a pre-installed AOTriton in their environment, if desired, and have the build pick it up by specifying the `AOTRITON_INSTALLED_PREFIX` env var, or have the build automatically detect and install the compatible version. As a third option, the user can also force AOTriton to build from source instead, using the `AOTRITON_INSTALL_FROM_SOURCE` env var. Also, with the changes in this PR, the cmake build process handles the tasks of copying aotriton .so and images directory from `torch/lib` to the installation path. 
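For reference, a minimal sketch of the three resulting build modes. The env var names (`AOTRITON_INSTALLED_PREFIX`, `AOTRITON_INSTALL_FROM_SOURCE`, `PYTORCH_ROCM_ARCH`) come from this change; the `USE_ROCM=1 python setup.py develop` invocation and the `/opt/aotriton` prefix are only assumed, typical placeholders and not part of this PR:

    # 1) Default: let aotriton.cmake detect and install the matching binary release.
    USE_ROCM=1 PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py develop

    # 2) Reuse a pre-installed AOTriton (hypothetical install prefix).
    AOTRITON_INSTALLED_PREFIX=/opt/aotriton USE_ROCM=1 python setup.py develop

    # 3) Force AOTriton to be built from source instead of using a binary package.
    AOTRITON_INSTALL_FROM_SOURCE=1 USE_ROCM=1 python setup.py develop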
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137443 Approved by: https://github.com/jithunnair-amd, https://github.com/jeffdaily Co-authored-by: Jithun Nair --- .ci/docker/libtorch/Dockerfile | 105 ++++++++++++ .ci/docker/manywheel/Dockerfile | 200 ++++++++++++++++++++++ .ci/docker/ubuntu-rocm/Dockerfile | 6 + .ci/manywheel/build_rocm.sh | 268 ++++++++++++++++++++++++++++++ cmake/External/aotriton.cmake | 15 +- 5 files changed, 584 insertions(+), 10 deletions(-) create mode 100644 .ci/docker/libtorch/Dockerfile create mode 100644 .ci/docker/manywheel/Dockerfile create mode 100755 .ci/manywheel/build_rocm.sh diff --git a/.ci/docker/libtorch/Dockerfile b/.ci/docker/libtorch/Dockerfile new file mode 100644 index 0000000000000..8737d753c9405 --- /dev/null +++ b/.ci/docker/libtorch/Dockerfile @@ -0,0 +1,105 @@ +ARG BASE_TARGET=base +ARG GPU_IMAGE=ubuntu:20.04 +FROM ${GPU_IMAGE} as base + +ENV DEBIAN_FRONTEND=noninteractive + +RUN apt-get clean && apt-get update +RUN apt-get install -y curl locales g++ git-all autoconf automake make cmake wget unzip sudo +# Just add everything as a safe.directory for git since these will be used in multiple places with git +RUN git config --global --add safe.directory '*' + +RUN locale-gen en_US.UTF-8 + +ENV LC_ALL en_US.UTF-8 +ENV LANG en_US.UTF-8 +ENV LANGUAGE en_US.UTF-8 + +# Install openssl +FROM base as openssl +ADD ./common/install_openssl.sh install_openssl.sh +RUN bash ./install_openssl.sh && rm install_openssl.sh + +# Install python +FROM base as python +ADD common/install_cpython.sh install_cpython.sh +RUN apt-get update -y && \ + apt-get install build-essential gdb lcov libbz2-dev libffi-dev \ + libgdbm-dev liblzma-dev libncurses5-dev libreadline6-dev \ + libsqlite3-dev libssl-dev lzma lzma-dev tk-dev uuid-dev zlib1g-dev -y && \ + bash ./install_cpython.sh && \ + rm install_cpython.sh && \ + apt-get clean + +FROM base as conda +ADD ./common/install_conda_docker.sh install_conda.sh +RUN bash ./install_conda.sh && rm install_conda.sh + +FROM base as cpu +# Install Anaconda +COPY --from=conda /opt/conda /opt/conda +# Install python +COPY --from=python /opt/python /opt/python +COPY --from=python /opt/_internal /opt/_internal +ENV PATH=/opt/conda/bin:/usr/local/cuda/bin:$PATH +# Install MKL +ADD ./common/install_mkl.sh install_mkl.sh +RUN bash ./install_mkl.sh && rm install_mkl.sh + +FROM cpu as cuda +ADD ./common/install_cuda.sh install_cuda.sh +ADD ./common/install_magma.sh install_magma.sh +ENV CUDA_HOME /usr/local/cuda + +FROM cuda as cuda11.8 +RUN bash ./install_cuda.sh 11.8 +RUN bash ./install_magma.sh 11.8 +RUN ln -sf /usr/local/cuda-11.8 /usr/local/cuda + +FROM cuda as cuda12.1 +RUN bash ./install_cuda.sh 12.1 +RUN bash ./install_magma.sh 12.1 +RUN ln -sf /usr/local/cuda-12.1 /usr/local/cuda + +FROM cuda as cuda12.4 +RUN bash ./install_cuda.sh 12.4 +RUN bash ./install_magma.sh 12.4 +RUN ln -sf /usr/local/cuda-12.4 /usr/local/cuda + +FROM cuda as cuda12.6 +RUN bash ./install_cuda.sh 12.6 +RUN bash ./install_magma.sh 12.6 +RUN ln -sf /usr/local/cuda-12.6 /usr/local/cuda + +FROM cpu as rocm +ARG PYTORCH_ROCM_ARCH +ENV PYTORCH_ROCM_ARCH ${PYTORCH_ROCM_ARCH} +ENV MKLROOT /opt/intel +# Adding ROCM_PATH env var so that LoadHip.cmake (even with logic updated for ROCm6.0) +# find HIP works for ROCm5.7. Not needed for ROCm6.0 and above. +# Remove below when ROCm5.7 is not in support matrix anymore. 
+ENV ROCM_PATH /opt/rocm +# No need to install ROCm as base docker image should have full ROCm install +#ADD ./common/install_rocm.sh install_rocm.sh +ADD ./common/install_rocm_drm.sh install_rocm_drm.sh +ADD ./common/install_rocm_magma.sh install_rocm_magma.sh +# gfortran and python needed for building magma from source for ROCm +RUN apt-get update -y && \ + apt-get install gfortran -y && \ + apt-get install python -y && \ + apt-get clean + +RUN bash ./install_rocm_drm.sh && rm install_rocm_drm.sh +RUN bash ./install_rocm_magma.sh && rm install_rocm_magma.sh + +FROM ${BASE_TARGET} as final +COPY --from=openssl /opt/openssl /opt/openssl +# Install patchelf +ADD ./common/install_patchelf.sh install_patchelf.sh +RUN bash ./install_patchelf.sh && rm install_patchelf.sh +# Install Anaconda +COPY --from=conda /opt/conda /opt/conda +# Install python +COPY --from=python /opt/python /opt/python +COPY --from=python /opt/_internal /opt/_internal +ENV PATH=/opt/conda/bin:/usr/local/cuda/bin:$PATH diff --git a/.ci/docker/manywheel/Dockerfile b/.ci/docker/manywheel/Dockerfile new file mode 100644 index 0000000000000..04298fd0ed023 --- /dev/null +++ b/.ci/docker/manywheel/Dockerfile @@ -0,0 +1,200 @@ +# syntax = docker/dockerfile:experimental +ARG ROCM_VERSION=3.7 +ARG BASE_CUDA_VERSION=11.8 + +ARG GPU_IMAGE=centos:7 +FROM centos:7 as base + +ENV LC_ALL en_US.UTF-8 +ENV LANG en_US.UTF-8 +ENV LANGUAGE en_US.UTF-8 + +ARG DEVTOOLSET_VERSION=9 + +# Note: This is required patch since CentOS have reached EOL +# otherwise any yum install setp will fail +RUN sed -i s/mirror.centos.org/vault.centos.org/g /etc/yum.repos.d/*.repo +RUN sed -i s/^#.*baseurl=http/baseurl=http/g /etc/yum.repos.d/*.repo +RUN sed -i s/^mirrorlist=http/#mirrorlist=http/g /etc/yum.repos.d/*.repo +RUN yum install -y wget curl perl util-linux xz bzip2 git patch which perl zlib-devel +# Just add everything as a safe.directory for git since these will be used in multiple places with git +RUN git config --global --add safe.directory '*' +RUN yum install -y yum-utils centos-release-scl +RUN yum-config-manager --enable rhel-server-rhscl-7-rpms +# Note: After running yum-config-manager --enable rhel-server-rhscl-7-rpms +# patch is required once again. Somehow this steps adds mirror.centos.org +RUN sed -i s/mirror.centos.org/vault.centos.org/g /etc/yum.repos.d/*.repo +RUN sed -i s/^#.*baseurl=http/baseurl=http/g /etc/yum.repos.d/*.repo +RUN sed -i s/^mirrorlist=http/#mirrorlist=http/g /etc/yum.repos.d/*.repo +RUN yum install -y devtoolset-${DEVTOOLSET_VERSION}-gcc devtoolset-${DEVTOOLSET_VERSION}-gcc-c++ devtoolset-${DEVTOOLSET_VERSION}-gcc-gfortran devtoolset-${DEVTOOLSET_VERSION}-binutils +ENV PATH=/opt/rh/devtoolset-${DEVTOOLSET_VERSION}/root/usr/bin:$PATH +ENV LD_LIBRARY_PATH=/opt/rh/devtoolset-${DEVTOOLSET_VERSION}/root/usr/lib64:/opt/rh/devtoolset-${DEVTOOLSET_VERSION}/root/usr/lib:$LD_LIBRARY_PATH + +RUN yum --enablerepo=extras install -y epel-release + +# cmake-3.18.4 from pip +RUN yum install -y python3-pip && \ + python3 -mpip install cmake==3.18.4 && \ + ln -s /usr/local/bin/cmake /usr/bin/cmake + +RUN yum install -y autoconf aclocal automake make sudo + +FROM base as openssl +# Install openssl (this must precede `build python` step) +# (In order to have a proper SSL module, Python is compiled +# against a recent openssl [see env vars above], which is linked +# statically. We delete openssl afterwards.) 
+ADD ./common/install_openssl.sh install_openssl.sh +RUN bash ./install_openssl.sh && rm install_openssl.sh + +# EPEL for cmake +FROM base as patchelf +# Install patchelf +ADD ./common/install_patchelf.sh install_patchelf.sh +RUN bash ./install_patchelf.sh && rm install_patchelf.sh +RUN cp $(which patchelf) /patchelf + +FROM patchelf as python +# build python +COPY manywheel/build_scripts /build_scripts +ADD ./common/install_cpython.sh /build_scripts/install_cpython.sh +RUN bash build_scripts/build.sh && rm -r build_scripts + +FROM base as cuda +ARG BASE_CUDA_VERSION=10.2 +# Install CUDA +ADD ./common/install_cuda.sh install_cuda.sh +RUN bash ./install_cuda.sh ${BASE_CUDA_VERSION} && rm install_cuda.sh + +FROM base as intel +# MKL +ADD ./common/install_mkl.sh install_mkl.sh +RUN bash ./install_mkl.sh && rm install_mkl.sh + +FROM base as magma +ARG BASE_CUDA_VERSION=10.2 +# Install magma +ADD ./common/install_magma.sh install_magma.sh +RUN bash ./install_magma.sh ${BASE_CUDA_VERSION} && rm install_magma.sh + +FROM base as jni +# Install java jni header +ADD ./common/install_jni.sh install_jni.sh +ADD ./java/jni.h jni.h +RUN bash ./install_jni.sh && rm install_jni.sh + +FROM base as libpng +# Install libpng +ADD ./common/install_libpng.sh install_libpng.sh +RUN bash ./install_libpng.sh && rm install_libpng.sh + +FROM ${GPU_IMAGE} as common +RUN sed -i s/mirror.centos.org/vault.centos.org/g /etc/yum.repos.d/*.repo +RUN sed -i s/^#.*baseurl=http/baseurl=http/g /etc/yum.repos.d/*.repo +RUN sed -i s/^mirrorlist=http/#mirrorlist=http/g /etc/yum.repos.d/*.repo +ENV LC_ALL en_US.UTF-8 +ENV LANG en_US.UTF-8 +ENV LANGUAGE en_US.UTF-8 +RUN yum install -y \ + aclocal \ + autoconf \ + automake \ + bison \ + bzip2 \ + curl \ + diffutils \ + file \ + git \ + make \ + patch \ + perl \ + unzip \ + util-linux \ + wget \ + which \ + xz \ + yasm +RUN yum install -y \ + https://repo.ius.io/ius-release-el7.rpm \ + https://ossci-linux.s3.amazonaws.com/epel-release-7-14.noarch.rpm + +RUN yum swap -y git git236-core +# git236+ would refuse to run git commands in repos owned by other users +# Which causes version check to fail, as pytorch repo is bind-mounted into the image +# Override this behaviour by treating every folder as safe +# For more details see https://github.com/pytorch/pytorch/issues/78659#issuecomment-1144107327 +RUN git config --global --add safe.directory "*" + +ENV SSL_CERT_FILE=/opt/_internal/certs.pem +# Install LLVM version +COPY --from=openssl /opt/openssl /opt/openssl +COPY --from=python /opt/python /opt/python +COPY --from=python /opt/_internal /opt/_internal +COPY --from=python /opt/python/cp39-cp39/bin/auditwheel /usr/local/bin/auditwheel +COPY --from=intel /opt/intel /opt/intel +COPY --from=patchelf /usr/local/bin/patchelf /usr/local/bin/patchelf +COPY --from=jni /usr/local/include/jni.h /usr/local/include/jni.h +COPY --from=libpng /usr/local/bin/png* /usr/local/bin/ +COPY --from=libpng /usr/local/bin/libpng* /usr/local/bin/ +COPY --from=libpng /usr/local/include/png* /usr/local/include/ +COPY --from=libpng /usr/local/include/libpng* /usr/local/include/ +COPY --from=libpng /usr/local/lib/libpng* /usr/local/lib/ +COPY --from=libpng /usr/local/lib/pkgconfig /usr/local/lib/pkgconfig + +FROM common as cpu_final +ARG BASE_CUDA_VERSION=10.1 +ARG DEVTOOLSET_VERSION=9 +# Install Anaconda +ADD ./common/install_conda_docker.sh install_conda.sh +RUN bash ./install_conda.sh && rm install_conda.sh +ENV PATH /opt/conda/bin:$PATH +RUN sed -i s/mirror.centos.org/vault.centos.org/g /etc/yum.repos.d/*.repo 
+RUN sed -i s/^#.*baseurl=http/baseurl=http/g /etc/yum.repos.d/*.repo +RUN sed -i s/^mirrorlist=http/#mirrorlist=http/g /etc/yum.repos.d/*.repo + +RUN yum install -y yum-utils centos-release-scl +RUN yum-config-manager --enable rhel-server-rhscl-7-rpms +RUN sed -i s/mirror.centos.org/vault.centos.org/g /etc/yum.repos.d/*.repo +RUN sed -i s/^#.*baseurl=http/baseurl=http/g /etc/yum.repos.d/*.repo +RUN sed -i s/^mirrorlist=http/#mirrorlist=http/g /etc/yum.repos.d/*.repo +RUN yum install -y devtoolset-${DEVTOOLSET_VERSION}-gcc devtoolset-${DEVTOOLSET_VERSION}-gcc-c++ devtoolset-${DEVTOOLSET_VERSION}-gcc-gfortran devtoolset-${DEVTOOLSET_VERSION}-binutils +ENV PATH=/opt/rh/devtoolset-${DEVTOOLSET_VERSION}/root/usr/bin:$PATH +ENV LD_LIBRARY_PATH=/opt/rh/devtoolset-${DEVTOOLSET_VERSION}/root/usr/lib64:/opt/rh/devtoolset-${DEVTOOLSET_VERSION}/root/usr/lib:$LD_LIBRARY_PATH + +# cmake is already installed inside the rocm base image, so remove if present +RUN rpm -e cmake || true +# cmake-3.18.4 from pip +RUN yum install -y python3-pip && \ + python3 -mpip install cmake==3.18.4 && \ + ln -s /usr/local/bin/cmake /usr/bin/cmake + +# ninja +RUN yum install -y ninja-build + +FROM cpu_final as cuda_final +RUN rm -rf /usr/local/cuda-${BASE_CUDA_VERSION} +COPY --from=cuda /usr/local/cuda-${BASE_CUDA_VERSION} /usr/local/cuda-${BASE_CUDA_VERSION} +COPY --from=magma /usr/local/cuda-${BASE_CUDA_VERSION} /usr/local/cuda-${BASE_CUDA_VERSION} +RUN ln -sf /usr/local/cuda-${BASE_CUDA_VERSION} /usr/local/cuda +ENV PATH=/usr/local/cuda/bin:$PATH + +FROM cpu_final as rocm_final +ARG ROCM_VERSION=3.7 +ARG PYTORCH_ROCM_ARCH +ENV PYTORCH_ROCM_ARCH ${PYTORCH_ROCM_ARCH} +# Adding ROCM_PATH env var so that LoadHip.cmake (even with logic updated for ROCm6.0) +# find HIP works for ROCm5.7. Not needed for ROCm6.0 and above. +# Remove below when ROCm5.7 is not in support matrix anymore. 
+ENV ROCM_PATH /opt/rocm +ENV MKLROOT /opt/intel +# No need to install ROCm as base docker image should have full ROCm install +#ADD ./common/install_rocm.sh install_rocm.sh +#RUN ROCM_VERSION=${ROCM_VERSION} bash ./install_rocm.sh && rm install_rocm.sh +ADD ./common/install_rocm_drm.sh install_rocm_drm.sh +RUN bash ./install_rocm_drm.sh && rm install_rocm_drm.sh +# cmake3 is needed for the MIOpen build +RUN ln -sf /usr/local/bin/cmake /usr/bin/cmake3 +ADD ./common/install_rocm_magma.sh install_rocm_magma.sh +RUN bash ./install_rocm_magma.sh && rm install_rocm_magma.sh +ADD ./common/install_miopen.sh install_miopen.sh +RUN bash ./install_miopen.sh ${ROCM_VERSION} && rm install_miopen.sh diff --git a/.ci/docker/ubuntu-rocm/Dockerfile b/.ci/docker/ubuntu-rocm/Dockerfile index 9de8423640104..e69d64cc5a6b7 100644 --- a/.ci/docker/ubuntu-rocm/Dockerfile +++ b/.ci/docker/ubuntu-rocm/Dockerfile @@ -106,6 +106,12 @@ COPY triton_version.txt triton_version.txt RUN if [ -n "${TRITON}" ]; then bash ./install_triton.sh; fi RUN rm install_triton.sh common_utils.sh triton-rocm.txt triton_version.txt +# This is needed by sccache +COPY ./common/install_openssl.sh install_openssl.sh +ENV OPENSSL_ROOT_DIR /opt/openssl +RUN bash ./install_openssl.sh +ENV OPENSSL_DIR /opt/openssl + # Install ccache/sccache (do this last, so we get priority in PATH) COPY ./common/install_cache.sh install_cache.sh ENV PATH /opt/cache/bin:$PATH diff --git a/.ci/manywheel/build_rocm.sh b/.ci/manywheel/build_rocm.sh new file mode 100755 index 0000000000000..703248d44aa91 --- /dev/null +++ b/.ci/manywheel/build_rocm.sh @@ -0,0 +1,268 @@ +#!/usr/bin/env bash + +set -ex + +export ROCM_HOME=/opt/rocm +export MAGMA_HOME=$ROCM_HOME/magma +# TODO: libtorch_cpu.so is broken when building with Debug info +export BUILD_DEBUG_INFO=0 + +# TODO Are these all used/needed? +export TH_BINARY_BUILD=1 +export USE_STATIC_CUDNN=1 +export USE_STATIC_NCCL=1 +export ATEN_STATIC_CUDA=1 +export USE_CUDA_STATIC_LINK=1 +export INSTALL_TEST=0 # dont install test binaries into site-packages +# Set RPATH instead of RUNPATH when using patchelf to avoid LD_LIBRARY_PATH override +export FORCE_RPATH="--force-rpath" + +# Keep an array of cmake variables to add to +if [[ -z "$CMAKE_ARGS" ]]; then + # These are passed to tools/build_pytorch_libs.sh::build() + CMAKE_ARGS=() +fi +if [[ -z "$EXTRA_CAFFE2_CMAKE_FLAGS" ]]; then + # These are passed to tools/build_pytorch_libs.sh::build_caffe2() + EXTRA_CAFFE2_CMAKE_FLAGS=() +fi + +# Determine ROCm version and architectures to build for +# +# NOTE: We should first check `DESIRED_CUDA` when determining `ROCM_VERSION` +if [[ -n "$DESIRED_CUDA" ]]; then + if ! echo "${DESIRED_CUDA}"| grep "^rocm" >/dev/null 2>/dev/null; then + export DESIRED_CUDA="rocm${DESIRED_CUDA}" + fi + # rocm3.7, rocm3.5.1 + ROCM_VERSION="$DESIRED_CUDA" + echo "Using $ROCM_VERSION as determined by DESIRED_CUDA" +else + echo "Must set DESIRED_CUDA" + exit 1 +fi + +# Package directories +WHEELHOUSE_DIR="wheelhouse$ROCM_VERSION" +LIBTORCH_HOUSE_DIR="libtorch_house$ROCM_VERSION" +if [[ -z "$PYTORCH_FINAL_PACKAGE_DIR" ]]; then + if [[ -z "$BUILD_PYTHONLESS" ]]; then + PYTORCH_FINAL_PACKAGE_DIR="/remote/wheelhouse$ROCM_VERSION" + else + PYTORCH_FINAL_PACKAGE_DIR="/remote/libtorch_house$ROCM_VERSION" + fi +fi +mkdir -p "$PYTORCH_FINAL_PACKAGE_DIR" || true + +# To make version comparison easier, create an integer representation. +ROCM_VERSION_CLEAN=$(echo ${ROCM_VERSION} | sed s/rocm//) +save_IFS="$IFS" +IFS=. 
ROCM_VERSION_ARRAY=(${ROCM_VERSION_CLEAN}) +IFS="$save_IFS" +if [[ ${#ROCM_VERSION_ARRAY[@]} == 2 ]]; then + ROCM_VERSION_MAJOR=${ROCM_VERSION_ARRAY[0]} + ROCM_VERSION_MINOR=${ROCM_VERSION_ARRAY[1]} + ROCM_VERSION_PATCH=0 +elif [[ ${#ROCM_VERSION_ARRAY[@]} == 3 ]]; then + ROCM_VERSION_MAJOR=${ROCM_VERSION_ARRAY[0]} + ROCM_VERSION_MINOR=${ROCM_VERSION_ARRAY[1]} + ROCM_VERSION_PATCH=${ROCM_VERSION_ARRAY[2]} +else + echo "Unhandled ROCM_VERSION ${ROCM_VERSION}" + exit 1 +fi +ROCM_INT=$(($ROCM_VERSION_MAJOR * 10000 + $ROCM_VERSION_MINOR * 100 + $ROCM_VERSION_PATCH)) + +# Required ROCm libraries +ROCM_SO_FILES=( + "libMIOpen.so" + "libamdhip64.so" + "libhipblas.so" + "libhipfft.so" + "libhiprand.so" + "libhipsolver.so" + "libhipsparse.so" + "libhsa-runtime64.so" + "libamd_comgr.so" + "libmagma.so" + "librccl.so" + "librocblas.so" + "librocfft.so" + "librocm_smi64.so" + "librocrand.so" + "librocsolver.so" + "librocsparse.so" + "libroctracer64.so" + "libroctx64.so" + "libhipblaslt.so" + "libhiprtc.so" +) + +if [[ $ROCM_INT -ge 60100 ]]; then + ROCM_SO_FILES+=("librocprofiler-register.so") +fi + +if [[ $ROCM_INT -ge 60200 ]]; then + ROCM_SO_FILES+=("librocm-core.so") +fi + +OS_NAME=`awk -F= '/^NAME/{print $2}' /etc/os-release` +if [[ "$OS_NAME" == *"CentOS Linux"* || "$OS_NAME" == *"AlmaLinux"* ]]; then + LIBGOMP_PATH="/usr/lib64/libgomp.so.1" + LIBNUMA_PATH="/usr/lib64/libnuma.so.1" + LIBELF_PATH="/usr/lib64/libelf.so.1" + if [[ "$OS_NAME" == *"CentOS Linux"* ]]; then + LIBTINFO_PATH="/usr/lib64/libtinfo.so.5" + else + LIBTINFO_PATH="/usr/lib64/libtinfo.so.6" + fi + LIBDRM_PATH="/opt/amdgpu/lib64/libdrm.so.2" + LIBDRM_AMDGPU_PATH="/opt/amdgpu/lib64/libdrm_amdgpu.so.1" + if [[ $ROCM_INT -ge 60100 && $ROCM_INT -lt 60300 ]]; then + # Below libs are direct dependencies of libhipsolver + LIBSUITESPARSE_CONFIG_PATH="/lib64/libsuitesparseconfig.so.4" + if [[ "$OS_NAME" == *"CentOS Linux"* ]]; then + LIBCHOLMOD_PATH="/lib64/libcholmod.so.2" + # Below libs are direct dependencies of libsatlas + LIBGFORTRAN_PATH="/lib64/libgfortran.so.3" + else + LIBCHOLMOD_PATH="/lib64/libcholmod.so.3" + # Below libs are direct dependencies of libsatlas + LIBGFORTRAN_PATH="/lib64/libgfortran.so.5" + fi + # Below libs are direct dependencies of libcholmod + LIBAMD_PATH="/lib64/libamd.so.2" + LIBCAMD_PATH="/lib64/libcamd.so.2" + LIBCCOLAMD_PATH="/lib64/libccolamd.so.2" + LIBCOLAMD_PATH="/lib64/libcolamd.so.2" + LIBSATLAS_PATH="/lib64/atlas/libsatlas.so.3" + # Below libs are direct dependencies of libsatlas + LIBQUADMATH_PATH="/lib64/libquadmath.so.0" + fi + MAYBE_LIB64=lib64 +elif [[ "$OS_NAME" == *"Ubuntu"* ]]; then + LIBGOMP_PATH="/usr/lib/x86_64-linux-gnu/libgomp.so.1" + LIBNUMA_PATH="/usr/lib/x86_64-linux-gnu/libnuma.so.1" + LIBELF_PATH="/usr/lib/x86_64-linux-gnu/libelf.so.1" + if [[ $ROCM_INT -ge 50300 ]]; then + LIBTINFO_PATH="/lib/x86_64-linux-gnu/libtinfo.so.6" + else + LIBTINFO_PATH="/lib/x86_64-linux-gnu/libtinfo.so.5" + fi + LIBDRM_PATH="/usr/lib/x86_64-linux-gnu/libdrm.so.2" + LIBDRM_AMDGPU_PATH="/usr/lib/x86_64-linux-gnu/libdrm_amdgpu.so.1" + if [[ $ROCM_INT -ge 60100 && $ROCM_INT -lt 60300 ]]; then + # Below libs are direct dependencies of libhipsolver + LIBCHOLMOD_PATH="/lib/x86_64-linux-gnu/libcholmod.so.3" + # Below libs are direct dependencies of libcholmod + LIBSUITESPARSE_CONFIG_PATH="/lib/x86_64-linux-gnu/libsuitesparseconfig.so.5" + LIBAMD_PATH="/lib/x86_64-linux-gnu/libamd.so.2" + LIBCAMD_PATH="/lib/x86_64-linux-gnu/libcamd.so.2" + LIBCCOLAMD_PATH="/lib/x86_64-linux-gnu/libccolamd.so.2" + 
LIBCOLAMD_PATH="/lib/x86_64-linux-gnu/libcolamd.so.2" + LIBMETIS_PATH="/lib/x86_64-linux-gnu/libmetis.so.5" + LIBLAPACK_PATH="/lib/x86_64-linux-gnu/liblapack.so.3" + LIBBLAS_PATH="/lib/x86_64-linux-gnu/libblas.so.3" + # Below libs are direct dependencies of libblas + LIBGFORTRAN_PATH="/lib/x86_64-linux-gnu/libgfortran.so.5" + LIBQUADMATH_PATH="/lib/x86_64-linux-gnu/libquadmath.so.0" + fi + MAYBE_LIB64=lib +fi +OS_SO_PATHS=($LIBGOMP_PATH $LIBNUMA_PATH\ + $LIBELF_PATH $LIBTINFO_PATH\ + $LIBDRM_PATH $LIBDRM_AMDGPU_PATH\ + $LIBSUITESPARSE_CONFIG_PATH\ + $LIBCHOLMOD_PATH $LIBAMD_PATH\ + $LIBCAMD_PATH $LIBCCOLAMD_PATH\ + $LIBCOLAMD_PATH $LIBSATLAS_PATH\ + $LIBGFORTRAN_PATH $LIBQUADMATH_PATH\ + $LIBMETIS_PATH $LIBLAPACK_PATH\ + $LIBBLAS_PATH) +OS_SO_FILES=() +for lib in "${OS_SO_PATHS[@]}" +do + file_name="${lib##*/}" # Substring removal of path to get filename + OS_SO_FILES[${#OS_SO_FILES[@]}]=$file_name # Append lib to array +done + +# rocBLAS library files +ROCBLAS_LIB_SRC=$ROCM_HOME/lib/rocblas/library +ROCBLAS_LIB_DST=lib/rocblas/library +ARCH=$(echo $PYTORCH_ROCM_ARCH | sed 's/;/|/g') # Replace ; seperated arch list to bar for grep +ARCH_SPECIFIC_FILES=$(ls $ROCBLAS_LIB_SRC | grep -E $ARCH) +OTHER_FILES=$(ls $ROCBLAS_LIB_SRC | grep -v gfx) +ROCBLAS_LIB_FILES=($ARCH_SPECIFIC_FILES $OTHER_FILES) + +# hipblaslt library files +HIPBLASLT_LIB_SRC=$ROCM_HOME/lib/hipblaslt/library +HIPBLASLT_LIB_DST=lib/hipblaslt/library +ARCH_SPECIFIC_FILES=$(ls $HIPBLASLT_LIB_SRC | grep -E $ARCH) +OTHER_FILES=$(ls $HIPBLASLT_LIB_SRC | grep -v gfx) +HIPBLASLT_LIB_FILES=($ARCH_SPECIFIC_FILES $OTHER_FILES) + +# ROCm library files +ROCM_SO_PATHS=() +for lib in "${ROCM_SO_FILES[@]}" +do + file_path=($(find $ROCM_HOME/lib/ -name "$lib")) # First search in lib + if [[ -z $file_path ]]; then + if [ -d "$ROCM_HOME/lib64/" ]; then + file_path=($(find $ROCM_HOME/lib64/ -name "$lib")) # Then search in lib64 + fi + fi + if [[ -z $file_path ]]; then + file_path=($(find $ROCM_HOME/ -name "$lib")) # Then search in ROCM_HOME + fi + if [[ -z $file_path ]]; then + echo "Error: Library file $lib is not found." 
>&2 + exit 1 + fi + ROCM_SO_PATHS[${#ROCM_SO_PATHS[@]}]="$file_path" # Append lib to array +done + +DEPS_LIST=( + ${ROCM_SO_PATHS[*]} + ${OS_SO_PATHS[*]} +) + +DEPS_SONAME=( + ${ROCM_SO_FILES[*]} + ${OS_SO_FILES[*]} +) + +DEPS_AUX_SRCLIST=( + "${ROCBLAS_LIB_FILES[@]/#/$ROCBLAS_LIB_SRC/}" + "${HIPBLASLT_LIB_FILES[@]/#/$HIPBLASLT_LIB_SRC/}" + "/opt/amdgpu/share/libdrm/amdgpu.ids" +) + +DEPS_AUX_DSTLIST=( + "${ROCBLAS_LIB_FILES[@]/#/$ROCBLAS_LIB_DST/}" + "${HIPBLASLT_LIB_FILES[@]/#/$HIPBLASLT_LIB_DST/}" + "share/libdrm/amdgpu.ids" +) + +# MIOpen library files +MIOPEN_SHARE_SRC=$ROCM_HOME/share/miopen/db +MIOPEN_SHARE_DST=share/miopen/db +MIOPEN_SHARE_FILES=($(ls $MIOPEN_SHARE_SRC | grep -E $ARCH)) +DEPS_AUX_SRCLIST+=(${MIOPEN_SHARE_FILES[@]/#/$MIOPEN_SHARE_SRC/}) +DEPS_AUX_DSTLIST+=(${MIOPEN_SHARE_FILES[@]/#/$MIOPEN_SHARE_DST/}) + +# RCCL library files +RCCL_SHARE_SRC=$ROCM_HOME/share/rccl/msccl-algorithms +RCCL_SHARE_DST=share/rccl/msccl-algorithms +RCCL_SHARE_FILES=($(ls $RCCL_SHARE_SRC)) +DEPS_AUX_SRCLIST+=(${RCCL_SHARE_FILES[@]/#/$RCCL_SHARE_SRC/}) +DEPS_AUX_DSTLIST+=(${RCCL_SHARE_FILES[@]/#/$RCCL_SHARE_DST/}) + +echo "PYTORCH_ROCM_ARCH: ${PYTORCH_ROCM_ARCH}" + +SCRIPTPATH="$( cd "$(dirname "$0")" ; pwd -P )" +if [[ -z "$BUILD_PYTHONLESS" ]]; then + BUILD_SCRIPT=build_common.sh +else + BUILD_SCRIPT=build_libtorch.sh +fi +source $SCRIPTPATH/${BUILD_SCRIPT} diff --git a/cmake/External/aotriton.cmake b/cmake/External/aotriton.cmake index fabbe8a48b624..47b8874c7ca26 100644 --- a/cmake/External/aotriton.cmake +++ b/cmake/External/aotriton.cmake @@ -1,6 +1,7 @@ macro(get_target_gpus_from_pytorch target_gpus) set(gfx90a_key MI200) set(gfx942_key MI300X) + set(gfx1100_key Navi31) foreach(X IN LISTS PYTORCH_ROCM_ARCH) set(key ${X}) @@ -21,25 +22,19 @@ if(NOT __AOTRITON_INCLUDED) # Replaces .ci/docker/aotriton_version.txt # Note packages information may have versions skipped (due to no ABI breaks) # But they must be listed from lower version to higher version - set(__AOTRITON_VER "0.7.1b") + set(__AOTRITON_VER "0.8b") set(__AOTRITON_MANYLINUX_LIST - "manylinux_2_17" # rocm6.1 - "manylinux_2_17" # rocm6.2 "manylinux_2_28" # rocm6.2 "manylinux_2_28" # rocm6.3 ) set(__AOTRITON_ROCM_LIST - "rocm6.1" - "rocm6.2" "rocm6.2" "rocm6.3" ) - set(__AOTRITON_CI_COMMIT "f6b28a9b7265b69e3df54ea6ba0237e8a8d6f736") + set(__AOTRITON_CI_COMMIT "6f8cbcac8a92775291bb1ba8f514d4beb350baf4") set(__AOTRITON_SHA256_LIST - "4f73c9271f95d18c1ef0d824bb6ca0ac63fe7795cfe786ffe4964287be5ecff2" # rocm6.1 - "df00412ae36fe5732d0a4601802bd3622b5dec12df7ec86027c5147adeb54c25" # rocm6.2 - "852d0e6e280cee3256fc5c7c3abed657594d7f56081d768ff8616c08bf9098b2" # rocm6.2 - "e4e3b06d2431e68e0096fcc8d3668cd5034ca0fd6fe236fb3b96774427d934b8" # rocm6.3 + "e938def5d32869fe2e00aec0300f354c9f157867bebdf2e104d732b94cb238d8" # rocm6.2 + "dc03d531ca399250b7d2fbdfa61929871788c6faeb7e462288e2a026e6b1e43d" # rocm6.3 ) set(__AOTRITON_Z "gz") From 438d6595d90ce6ae18eba30740a955c31200dc16 Mon Sep 17 00:00:00 2001 From: Xinya Zhang Date: Tue, 28 Jan 2025 18:34:23 +0000 Subject: [PATCH 03/27] [ROCm] Bump AOTriton to 0.8.2b (#145508) We received reports AOTriton kernels mishandles the bias pointer and it causes NaN during fine-tuning llama3.2-11b vision model. This PR will fix the problem. Note: this AOTriton 0.8.1b adds head dimension 512 support and thus the binary size increases, but it is considered experimental and will not be enabled right now. 
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145508 Approved by: https://github.com/jeffdaily --- cmake/External/aotriton.cmake | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/cmake/External/aotriton.cmake b/cmake/External/aotriton.cmake index 47b8874c7ca26..aacdbe2e955e7 100644 --- a/cmake/External/aotriton.cmake +++ b/cmake/External/aotriton.cmake @@ -22,7 +22,7 @@ if(NOT __AOTRITON_INCLUDED) # Replaces .ci/docker/aotriton_version.txt # Note packages information may have versions skipped (due to no ABI breaks) # But they must be listed from lower version to higher version - set(__AOTRITON_VER "0.8b") + set(__AOTRITON_VER "0.8.2b") set(__AOTRITON_MANYLINUX_LIST "manylinux_2_28" # rocm6.2 "manylinux_2_28" # rocm6.3 @@ -31,10 +31,10 @@ if(NOT __AOTRITON_INCLUDED) "rocm6.2" "rocm6.3" ) - set(__AOTRITON_CI_COMMIT "6f8cbcac8a92775291bb1ba8f514d4beb350baf4") + set(__AOTRITON_CI_COMMIT "b24f43a9771622faa157155568b9a200c3b49e41") set(__AOTRITON_SHA256_LIST - "e938def5d32869fe2e00aec0300f354c9f157867bebdf2e104d732b94cb238d8" # rocm6.2 - "dc03d531ca399250b7d2fbdfa61929871788c6faeb7e462288e2a026e6b1e43d" # rocm6.3 + "66445e6b0209b9f4080743b839cc9d424054dc5c8d07363f9f27f109231c324a" # rocm6.2 + "16356dc1813cf3e60da23eb2f29440cb35eedd3a2fbf81e6093a0bc42056ad08" # rocm6.3 ) set(__AOTRITON_Z "gz") From 326b5a238c49476f94715daa4494fc72daf6fce3 Mon Sep 17 00:00:00 2001 From: Xinya Zhang Date: Fri, 7 Mar 2025 22:10:07 +0000 Subject: [PATCH 04/27] Backport AOTriton 0.9.2b This is backporting the following commit: [ROCm] Bump AOTriton to 0.9.2b (#148433) Notable new features/optimizations for SDPA operators on AMD systems from AOTriton 0.9b: * Optimize these Non-power-of-two head dimensions: 48, 80, 96, 160, 192, 224. Inputs with these head dimensions do not need padding to power-of-two anymore. * `is_causal=True` cases are now supported with persistent dynamic algorithm, which requires an atomic tensor but does load balance between different CTAs * `dropout_p > 0.0` cases now support full 64-bit offsets and use all i64x4 PRNG outputs * The precise AOTriton shared library version can now be identified with `readelf -p .comment libaotriton_v2.so` + However, this does not guarantee the GPU images stored under `aotriton.images` have the same version, since they can be overwritten. * The newly added fused backward kernel will be used for smaller workloads, due to less kernel invocation overhead. * Support gfx1201 (RX 9070XT). 
Need to be enabled at runtime with `TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1` Pull Request resolved: https://github.com/pytorch/pytorch/pull/148433 Approved by: https://github.com/jeffdaily --- .../native/transformers/cuda/attention.cu | 104 ++- .../transformers/cuda/attention_backward.cu | 97 ++- .../native/transformers/cuda/sdp_utils.cpp | 33 +- .../transformers/hip/aotriton_adapter.h | 6 + .../hip/flash_attn/aot/mha_all_aot.hip | 790 ++++++++++++++++++ .../transformers/hip/flash_attn/flash_api.h | 655 +++++++++++++++ .../transformers/hip/flash_attn/flash_api.hip | 504 ----------- cmake/External/aotriton.cmake | 11 +- 8 files changed, 1632 insertions(+), 568 deletions(-) create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h delete mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/flash_api.hip diff --git a/aten/src/ATen/native/transformers/cuda/attention.cu b/aten/src/ATen/native/transformers/cuda/attention.cu index 0f9356a7f3063..d2e831c1a3f11 100644 --- a/aten/src/ATen/native/transformers/cuda/attention.cu +++ b/aten/src/ATen/native/transformers/cuda/attention.cu @@ -1077,8 +1077,8 @@ std::tuple _efficient_ auto ret = aotriton::v2::flash::check_gpu(stream); if (hipSuccess != ret) { TORCH_CHECK(false, - "[AOTriton] Accelerated SDPA only supports MI200/MI300X/Navi31 GPUs" - " (gfx90a:sramecc+:xnack-/gfx942:sramecc+:xnack-/gfx1100)") + "[AOTriton] Accelerated SDPA only supports MI200/MI300X/7900XTX/9070XT GPUs" + " (gfx90a/gfx942/gfx1100/gfx1201)") } // AOTriton may accept aligned on logsumexp tensor in the future for better @@ -1086,10 +1086,13 @@ std::tuple _efficient_ // compute_logsumexp is false constexpr int kAlignLSE = 1; res = at::empty({B, M, num_heads, Kv}, query.options()); + at::Tensor softmax_lse; logsumexp = at::empty( - { B, num_heads, max_seqlen_q }, + { B, num_heads, compute_logsumexp ? max_seqlen_q : 0}, query.options().dtype(at::ScalarType::Float)); - at::Tensor softmax_lse = logsumexp.view({B * num_heads, max_seqlen_q}); + if (compute_logsumexp) { + softmax_lse = logsumexp.view({B * num_heads, max_seqlen_q}); + } at::Tensor q_t = query.transpose(1, 2); at::Tensor k_t = key.transpose(1, 2); at::Tensor v_t = value.transpose(1, 2); @@ -1105,40 +1108,68 @@ std::tuple _efficient_ const auto softmax_scale = sdp::calculate_scale(query, scale).as_float_unchecked(); + at::Tensor atomic_counter; + if (is_causal) { + atomic_counter = at::zeros({1}, query.options().dtype(at::kInt)); + } + using aotriton::v2::flash::attn_fwd; using sdp::aotriton_adapter::mk_aotensor; using sdp::aotriton_adapter::mk_aoscalartensor; using sdp::aotriton_adapter::mk_philoxtensor; + using sdp::aotriton_adapter::mk_atomictensor; aotriton::TensorView<4> empty_t4(0, {0, 0, 0, 0}, {0, 0, 0, 0}, aotriton::DType::kFloat16); + aotriton::TensorView<2> empty_t2(0, {0, 0}, {0, 0}, aotriton::DType::kFloat32); at::Tensor softmax_fa_t = at::empty({ 0, 0, 0, 0 }, query.options()); const bool use_philox_state = in_capture_stream; auto seed = use_philox_state ? mk_philoxtensor(philox_state.seed_.ptr) : mk_aoscalartensor(seed_t); auto offset1 = use_philox_state ? mk_philoxtensor(philox_state.offset_.ptr) : mk_aoscalartensor(offset_t); auto offset2 = use_philox_state ? philox_state.offset_intragraph_ : 0; - auto seed_output = use_philox_state ? mk_philoxtensor(seed_t.data_ptr()) : mk_philoxtensor(nullptr); - auto offset_output = use_philox_state ? 
mk_philoxtensor(offset_t.data_ptr()) : mk_philoxtensor(nullptr); + auto seed_output = mk_philoxtensor(use_philox_state ? seed_t.data_ptr() : nullptr); + auto offset_output = mk_philoxtensor(use_philox_state ? offset_t.data_ptr() : nullptr); + auto persistent_counter = mk_atomictensor(is_causal ? atomic_counter.data_ptr() : nullptr); hipError_t err; // TODO: Error handling - err = attn_fwd(mk_aotensor(q_t, "q"), - mk_aotensor(k_t, "k"), - mk_aotensor(v_t, "v"), - bias.has_value() ? mk_aotensor(bias.value(), "bias"): empty_t4, - softmax_scale, - mk_aotensor<2>(softmax_lse, "M"), - mk_aotensor(output_t, "Out"), - dropout_p, - seed, - offset1, - offset2, - seed_output, - offset_output, - mk_aotensor(softmax_fa_t, "encoded_softmax"), - is_causal, - stream); - if (!compute_logsumexp) { - // Set the tensor to empty when compute_logsumexp is false - logsumexp = at::empty( - { B * num_heads, max_seqlen_q, 0 }, - query.options().dtype(at::ScalarType::Float)); + if (seqstart_q.has_value()) { + // varlen aka nested tensor + err = attn_fwd_compact_varlen(mk_aotensor(q_t, "q"), + mk_aotensor(k_t, "k"), + mk_aotensor(v_t, "v"), + bias.has_value() ? mk_aotensor(bias.value(), "bias"): empty_t4, + mk_aotensor<1>(seqstart_q.value(), "cu_seqlens_q"), + mk_aotensor<1>(seqstart_k.value(), "cu_seqlens_k"), + max_seqlen_q, + max_seqlen_k, + softmax_scale, + compute_logsumexp ? mk_aotensor<2>(softmax_lse, "M") : empty_t2, + mk_aotensor(output_t, "Out"), + dropout_p, + seed, + offset1, + offset2, + seed_output, + offset_output, + mk_aotensor(softmax_fa_t, "encoded_softmax"), + is_causal, + persistent_counter, + stream); + } else { + err = attn_fwd(mk_aotensor(q_t, "q"), + mk_aotensor(k_t, "k"), + mk_aotensor(v_t, "v"), + bias.has_value() ? mk_aotensor(bias.value(), "bias"): empty_t4, + softmax_scale, + compute_logsumexp ? 
mk_aotensor<2>(softmax_lse, "M") : empty_t2, + mk_aotensor(output_t, "Out"), + dropout_p, + seed, + offset1, + offset2, + seed_output, + offset_output, + mk_aotensor(softmax_fa_t, "encoded_softmax"), + is_causal, + persistent_counter, + stream); } #else // CUDA Implementation @@ -1401,15 +1432,24 @@ at::Tensor& _fill_mem_eff_dropout_mask_( #if defined(USE_MEM_EFF_ATTENTION) #ifdef USE_ROCM - using aotriton::v2::flash::debug_fill_dropout_rng; + using aotriton::v2::flash::debug_simulate_encoded_softmax; using sdp::aotriton_adapter::mk_aotensor; + using sdp::aotriton_adapter::mk_aoscalartensor; + at::cuda::CUDAGuard device_guard(self.device()); cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + at::Tensor seed_t, offset_t; + const auto options = at::dtype(at::kLong).device(at::kCUDA); + seed_t = at::scalar_tensor(at::Scalar(seed), options); + offset_t = at::scalar_tensor(at::Scalar(offset), options); hipError_t err; // TODO: Error handling - err = debug_fill_dropout_rng(mk_aotensor(self, "r"), - static_cast(seed), - static_cast(offset), - stream); + err = debug_simulate_encoded_softmax(mk_aotensor(self, "r"), + dropout_p, + mk_aoscalartensor(seed_t), + mk_aoscalartensor(offset_t), + 0, + stream); #else at::PhiloxCudaState rng_engine_inputs; rng_engine_inputs = at::PhiloxCudaState(seed, offset); diff --git a/aten/src/ATen/native/transformers/cuda/attention_backward.cu b/aten/src/ATen/native/transformers/cuda/attention_backward.cu index e809f97265774..017661a3f3637 100644 --- a/aten/src/ATen/native/transformers/cuda/attention_backward.cu +++ b/aten/src/ATen/native/transformers/cuda/attention_backward.cu @@ -383,8 +383,8 @@ _efficient_attention_backward( auto ret = aotriton::v2::flash::check_gpu(stream); if (hipSuccess != ret) { TORCH_CHECK(false, - "[AOTriton] Accelerated SDPA only supports MI200/MI300X/Navi31 GPUs" - " (gfx90a:sramecc+:xnack-/gfx942:sramecc+:xnack-/gfx1100)") + "[AOTriton] Accelerated SDPA only supports MI200/MI300X/7900XTX/9070XT GPUs" + " (gfx90a/gfx942/gfx1100/gfx1201)") } const auto softmax_scale = sdp::calculate_scale(query, scale).as_float_unchecked(); bool is_causal; @@ -404,33 +404,86 @@ _efficient_attention_backward( at::Tensor dv_t = grad_v.permute({0,2,1,3}); at::Tensor dout_t = grad_out.permute({0,2,1,3}); at::Tensor softmax_lse = logsumexp.view({B * nH, max_seqlen_q}); - at::Tensor delta = at::empty_like(softmax_lse).contiguous(); hipError_t err; using aotriton::v2::flash::attn_bwd; + using aotriton::v2::flash::attn_bwd_fused; + using aotriton::v2::flash::attn_bwd_compact_varlen; using sdp::aotriton_adapter::mk_aotensor; using sdp::aotriton_adapter::mk_aoscalartensor; using sdp::aotriton_adapter::cast_dtype; aotriton::TensorView<4> empty_t4(0, {0, 0, 0, 0}, {0, 0, 0, 0}, cast_dtype(query.dtype())); - err = attn_bwd(mk_aotensor(q_t, "q"), - mk_aotensor(k_t, "k"), - mk_aotensor(v_t, "v"), - bias.has_value() ? mk_aotensor(bias.value(), "bias") : empty_t4, - softmax_scale, - mk_aotensor(out_t, "out"), - mk_aotensor(dout_t, "dout"), - mk_aotensor(dq_t, "dq"), - mk_aotensor(dk_t, "dk"), - mk_aotensor(dv_t, "dv"), - bias_requires_grad ? 
mk_aotensor(grad_bias, "db") : empty_t4, - mk_aotensor<2>(softmax_lse, "L"), - mk_aotensor<2>(delta, "delta"), - float(dropout_p), - mk_aoscalartensor(philox_seed), - mk_aoscalartensor(philox_offset), - 0, - is_causal, - stream); + if (cu_seqlens_q.has_value()) { + at::Tensor delta = at::empty_like(softmax_lse).contiguous(); + // varlen aka Nested tensor + err = attn_bwd_compact_varlen(mk_aotensor(q_t, "q"), + mk_aotensor(k_t, "k"), + mk_aotensor(v_t, "v"), + mk_aotensor<1>(cu_seqlens_q.value(), "cu_seqlens_q"), + mk_aotensor<1>(cu_seqlens_k.value(), "cu_seqlens_k"), + max_seqlen_q, + max_seqlen_k, + bias.has_value() ? mk_aotensor(bias.value(), "bias") : empty_t4, + softmax_scale, + mk_aotensor(out_t, "out"), + mk_aotensor(dout_t, "dout"), + mk_aotensor(dq_t, "dq"), + mk_aotensor(dk_t, "dk"), + mk_aotensor(dv_t, "dv"), + bias_requires_grad ? mk_aotensor(grad_bias, "db") : empty_t4, + mk_aotensor<2>(softmax_lse, "L"), + mk_aotensor<2>(delta, "delta"), + float(dropout_p), + mk_aoscalartensor(philox_seed), + mk_aoscalartensor(philox_offset), + 0, + is_causal, + stream); + } else { + auto d_head = Kv; + bool use_fused_bwd = d_head <= 192 && d_head * max_seqlen_q < 64 * 512; + if (use_fused_bwd) { + err = attn_bwd_fused(mk_aotensor(q_t, "q"), + mk_aotensor(k_t, "k"), + mk_aotensor(v_t, "v"), + bias.has_value() ? mk_aotensor(bias.value(), "bias") : empty_t4, + softmax_scale, + mk_aotensor(out_t, "out"), + mk_aotensor(dout_t, "dout"), + mk_aotensor(dq_t, "dq"), + mk_aotensor(dk_t, "dk"), + mk_aotensor(dv_t, "dv"), + bias_requires_grad ? mk_aotensor(grad_bias, "db") : empty_t4, + mk_aotensor<2>(softmax_lse, "L"), + float(dropout_p), + mk_aoscalartensor(philox_seed), + mk_aoscalartensor(philox_offset), + 0, + is_causal, + stream); + } else { + at::Tensor delta = at::empty_like(softmax_lse).contiguous(); + err = attn_bwd(mk_aotensor(q_t, "q"), + mk_aotensor(k_t, "k"), + mk_aotensor(v_t, "v"), + bias.has_value() ? mk_aotensor(bias.value(), "bias") : empty_t4, + softmax_scale, + mk_aotensor(out_t, "out"), + mk_aotensor(dout_t, "dout"), + mk_aotensor(dq_t, "dq"), + mk_aotensor(dk_t, "dk"), + mk_aotensor(dv_t, "dv"), + bias_requires_grad ? 
mk_aotensor(grad_bias, "db") : empty_t4, + mk_aotensor<2>(softmax_lse, "L"), + mk_aotensor<2>(delta, "delta"), + float(dropout_p), + mk_aoscalartensor(philox_seed), + mk_aoscalartensor(philox_offset), + 0, + is_causal, + stream); + } + } #else at::Tensor workspace; cudaDeviceProp* p = at::cuda::getDeviceProperties(query.device().index()); diff --git a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp index a61d95312fbe3..45e202b8c0e3f 100644 --- a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp +++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp @@ -26,6 +26,11 @@ #endif #endif +// Avoid potential compiler -Wall -Werror complains undefined macro +#ifndef AOTRITON_VERSION_MINOR +#define AOTRITON_VERSION_MINOR 0 +#endif + /** * Note [SDPA Runtime Dispatch] * SDPA relies on a runtime dispatch mechanism to select the appropriate @@ -83,8 +88,13 @@ int64_t minimum_gemm_alignment(sdp_params const& params) { } bool check_head_dim_size_flash(sdp_params const& params, bool debug) { +#if USE_ROCM_ATTENTION && AOTRITON_VERSION_MINOR >= 9 + // AOTriton 0.9+ supports head_dim up to 512 + const auto max_size = c10::SymInt(512); +#else // All head_dim sizes must be equal and less than 256 const auto max_size = c10::SymInt(256); +#endif const auto query_size_last = params.query.sym_size(-1); const auto key_size_last = params.key.sym_size(-1); const auto value_size_last = params.value.sym_size(-1); @@ -207,6 +217,16 @@ bool check_flash_attention_hardware_support(sdp_params const& params, bool debug " Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1."); return false; } +#if AOTRITON_VERSION_MINOR >= 9 + if (aotriton::isArchExperimentallySupported(stream)) { + static const bool enable_experimental = c10::utils::check_env("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL") == true; + if (!enable_experimental) { + TORCH_WARN_ONCE("Flash Efficient attention on Current AMD GPU is still experimental." + " Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1."); + return false; + } + } +#endif } #else return false; @@ -243,15 +263,16 @@ bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug) } return false; } - c10::string_view arch(dprops->gcnArchName); - if (arch == "gfx1100") { - static const bool enable_navi3x = c10::utils::check_env("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL") == true; - if (!enable_navi3x) { - TORCH_WARN_ONCE("Memory Efficient attention on Navi31 GPU is still experimental." - " Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1."); +#if AOTRITON_VERSION_MINOR >= 9 + if (aotriton::isArchExperimentallySupported(stream)) { + static const bool enable_experimental = c10::utils::check_env("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL") == true; + if (!enable_experimental) { + TORCH_WARN_ONCE("Mem Efficient attention on Current AMD GPU is still experimental." 
+ " Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1."); return false; } } +#endif #else return false; #endif diff --git a/aten/src/ATen/native/transformers/hip/aotriton_adapter.h b/aten/src/ATen/native/transformers/hip/aotriton_adapter.h index 57d5c34444390..1623852b249fe 100644 --- a/aten/src/ATen/native/transformers/hip/aotriton_adapter.h +++ b/aten/src/ATen/native/transformers/hip/aotriton_adapter.h @@ -127,6 +127,12 @@ inline aotriton::TensorView<0> mk_philoxtensor(const int64_t* ptr) aotriton::DType::kUInt64); // AOTriton excepts unsigned int64 } +inline aotriton::TensorView<0> mk_atomictensor(const int32_t* ptr) +{ + return aotriton::TensorView<0>(reinterpret_cast(ptr), + aotriton::DType::kInt32); +} + } // namespace aotriton_adapter } // namespace sdp diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip b/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip new file mode 100644 index 0000000000000..adaa837f755cb --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip @@ -0,0 +1,790 @@ +/****************************************************************************** + * Copyright (c) 2023, Advanced Micro Devices, Inc. + * Copyright (c) 2022, Tri Dao. + * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + ******************************************************************************/ +#include +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS + +#include +#include + +#include + +#ifdef USE_FLASH_ATTENTION +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#include +#include +#include +#include +#endif + +#include +#include + +#include + +// AOTriton headers +#include +#include +#include + +#if AOTRITON_VERSION_MINOR != 9 +#error "This adaptor code is only tested with AOTriton 0.9.x" +#endif + +namespace pytorch_flash { + +namespace { + +void check_gpu_arch(hipStream_t stream) { + auto ret = aotriton::v2::flash::check_gpu(stream); + if (hipSuccess != ret) { + TORCH_CHECK(false, + "[AOTriton] Accelerated SDPA only supports MI200/MI300X/Navi31 GPUs" + " (gfx90a:sramecc+:xnack-/gfx942:sramecc+:xnack-/gfx1100)") + } +} + +// We want to checkpoint and save the RNG state for backward if dropout +// We get the default generator and return the seed and offset which will +// be used in the backward function +std::tuple +prepare_philox_arguments(float p_dropout, int64_t counter_offset) { + at::Tensor seed_t, offset_t; + at::PhiloxCudaState philox_state; + bool use_philox_state = false; + if (p_dropout <= 0.0) { + seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); + offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); + return { seed_t, offset_t, philox_state, use_philox_state }; + } + auto gen = at::get_generator_or_default(std::nullopt, at::cuda::detail::getDefaultCUDAGenerator()); + std::lock_guard lock(gen->mutex_); + philox_state = gen->philox_cuda_state(counter_offset); + if (at::cuda::currentStreamCaptureStatus() == at::cuda::CaptureStatus::None) { + auto [seed, offset] = at::cuda::philox::unpack(philox_state); + seed_t = at::scalar_tensor(at::Scalar(static_cast(seed)), at::dtype(at::kLong).device(at::kCUDA)); + offset_t = at::scalar_tensor(at::Scalar(static_cast(offset)), at::dtype(at::kLong).device(at::kCUDA)); + } else { + // See Note [CUDA Graph-safe RNG states] about the design + use_philox_state = true; + seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); + offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); + } + + return { seed_t, offset_t, philox_state, use_philox_state }; +} + + +} + +#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA") +#define CHECK_SHAPE(x, ...) 
TORCH_CHECK(x.sizes() == at::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") + +std::tuple +mha_fwd_aot(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size + std::optional &out_, // batch_size x seqlen_q x num_heads x head_size + std::optional &alibi_slopes_, // num_heads or batch_size x num_heads + const float p_dropout, + const float softmax_scale, + bool is_causal, + int window_size_left, + int window_size_right, + const bool return_softmax, + const std::optional& gen_) { + auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream(); + check_gpu_arch(stream); + + auto q_dtype = q.dtype(); + TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16, + "FlashAttention only support fp16 and bf16 data type"); + TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); + TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); + + CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); + + // FIXME: ROCM probably does not need this + TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + + const auto sizes = q.sizes(); + + const int batch_size = sizes[0]; + int seqlen_q = sizes[1]; + int num_heads = sizes[2]; + const int head_size_og = sizes[3]; + const int seqlen_k = k.size(1); + const int num_heads_k = k.size(2); + TORCH_CHECK(batch_size > 0, "batch size must be positive"); + TORCH_CHECK(head_size_og % 8 == 0, "head_size must be a multiple of 8, this is ensured by padding!"); + TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256"); + TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + + if (seqlen_q == 1) { is_causal = false; } // causal=true is the same as causal=false in this case + if (is_causal) { window_size_right = 0; } + + CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og); + CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_og); + CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_og); + + at::Tensor q_padded, k_padded, v_padded; + q_padded = q; + k_padded = k; + v_padded = v; + + at::Tensor out; + if (out_.has_value()) { + out = out_.value(); + TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs"); + CHECK_DEVICE(out); + TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); + CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og); + if (head_size_og % 8 != 0) { out = at::empty_like(q_padded); } + } else { + out = at::empty_like(q_padded); + } + + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size = round_multiple(head_size_og, 8); + const int head_size_rounded = round_multiple(head_size, 32); + + auto [seed_t, offset_t, philox_state, use_philox_state] = + prepare_philox_arguments(p_dropout, batch_size * num_heads * 32); + + // Transpose tensors to meet AOTriton's Flash API + at::Tensor q_t = q_padded.permute({0,2,1,3}); + at::Tensor k_t = k_padded.permute({0,2,1,3}); + at::Tensor v_t = v_padded.permute({0,2,1,3}); + 
at::Tensor output_t = out.permute({0,2,1,3}); + + auto opts = q.options(); + at::Tensor M = at::empty({batch_size * num_heads, seqlen_q}, opts.dtype(at::kFloat)); // aka softmax_lse + + at::Tensor softmax_fa_t; + if (return_softmax) { + softmax_fa_t = at::empty({batch_size, num_heads, seqlen_q, seqlen_k}, opts); + } else { + softmax_fa_t = at::empty({ 0, 0, 0, 0 }, opts); + } + + at::Tensor atomic_counter; + if (is_causal) { + atomic_counter = at::zeros({1}, opts.dtype(at::kInt)); + } + + hipError_t err; // TODO: Error handling + using aotriton::v2::flash::attn_fwd; + using sdp::aotriton_adapter::mk_aotensor; + using sdp::aotriton_adapter::mk_aoscalartensor; + using sdp::aotriton_adapter::mk_philoxtensor; + using sdp::aotriton_adapter::mk_atomictensor; + using sdp::aotriton_adapter::cast_dtype; + aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype())); + auto seed = use_philox_state ? mk_philoxtensor(philox_state.seed_.ptr) : mk_aoscalartensor(seed_t); + auto offset1 = use_philox_state ? mk_philoxtensor(philox_state.offset_.ptr) : mk_aoscalartensor(offset_t); + auto offset2 = use_philox_state ? philox_state.offset_intragraph_ : 0; + auto seed_output = mk_philoxtensor(use_philox_state ? seed_t.data_ptr() : nullptr); + auto offset_output = mk_philoxtensor(use_philox_state ? offset_t.data_ptr() : nullptr); + auto persistent_counter = mk_atomictensor(is_causal ? atomic_counter.data_ptr() : nullptr); + err = attn_fwd(mk_aotensor(q_t, "q"), + mk_aotensor(k_t, "k"), + mk_aotensor(v_t, "v"), + empty_bias, + softmax_scale, + mk_aotensor<2>(M, "M"), + mk_aotensor(output_t, "Out"), + p_dropout, + seed, + offset1, + offset2, + seed_output, + offset_output, + mk_aotensor(softmax_fa_t, "encoded_softmax"), + is_causal, + persistent_counter, + stream); + + return {out, q_padded, k_padded, v_padded, M.view({batch_size, num_heads, seqlen_q}), seed_t, offset_t, softmax_fa_t}; +} + +std::tuple +mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + std::optional &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &cu_seqlens_q, // b+1 + const at::Tensor &cu_seqlens_k, // b+1 + std::optional &seqused_k, // b. If given, only this many elements of each batch element's keys are used. 
+ std::optional &block_table_, // batch_size x max_num_blocks_per_seq + std::optional &alibi_slopes_, // num_heads or b x num_heads + int max_seqlen_q, + const int max_seqlen_k, + const float p_dropout, + const float softmax_scale, + const bool zero_tensors, + bool is_causal, + int window_size_left, + int window_size_right, + const bool return_softmax, + const std::optional& gen_) { + TORCH_CHECK(!seqused_k.has_value(), "[ROCm] mha_varlen_fwd: seqused_k must be nullopt"); + const bool paged_KV = block_table_.has_value(); + TORCH_CHECK(!paged_KV, "[ROCm] mha_varlen_fwd: block_table_ must be nullopt"); + TORCH_CHECK(!alibi_slopes_.has_value(), "[ROCm] mha_varlen_fwd: alibi_slopes_ must be nullopt"); + + at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()}; + auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream(); + check_gpu_arch(stream); + + auto q_dtype = q.dtype(); + TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16, + "FlashAttention only support fp16 and bf16 data type"); + TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); + TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); + TORCH_CHECK(cu_seqlens_q.dtype() == at::kInt, "cu_seqlens_q must have dtype int32"); + TORCH_CHECK(cu_seqlens_k.dtype() == at::kInt, "cu_seqlens_k must have dtype int32"); + + CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); + CHECK_DEVICE(cu_seqlens_q); + CHECK_DEVICE(cu_seqlens_k); + + TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + CHECK_CONTIGUOUS(cu_seqlens_q); + CHECK_CONTIGUOUS(cu_seqlens_k); + + const auto sizes = q.sizes(); + + const int batch_size = cu_seqlens_q.numel() - 1; + int num_heads = sizes[1]; + const int head_size_og = sizes[2]; + const int num_heads_k = paged_KV ? 
k.size(2) : k.size(1); + + if (max_seqlen_q == 1 && !alibi_slopes_.has_value()) { + is_causal = false; + } // causal=true is the same as causal=false in this case + + at::Tensor temp_q = q; + const int total_q = temp_q.sizes()[0]; + + TORCH_CHECK(batch_size > 0, "batch size must be positive"); + TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256"); + TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + + if (window_size_left >= max_seqlen_k) { + window_size_left = -1; + } + if (window_size_right >= max_seqlen_k) { + window_size_right = -1; + } + + CHECK_SHAPE(temp_q, total_q, num_heads, head_size_og); + const int total_k = k.size(0); + CHECK_SHAPE(k, total_k, num_heads_k, head_size_og); + CHECK_SHAPE(v, total_k, num_heads_k, head_size_og); + CHECK_SHAPE(cu_seqlens_q, batch_size + 1); + CHECK_SHAPE(cu_seqlens_k, batch_size + 1); + + // AOTriton's varlen API needs input shapes be + // (1, num_heads, total sequence lenght, head dimension) + at::Tensor q_padded, k_padded, v_padded; + at::Tensor out, out_padded; + q_padded = q.unsqueeze(0).transpose(1, 2); + k_padded = k.unsqueeze(0).transpose(1, 2); + v_padded = v.unsqueeze(0).transpose(1, 2); + if (out_.has_value()) { + out = out_.value(); + TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs"); + CHECK_DEVICE(out); + TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); + CHECK_SHAPE(out, total_q, num_heads, head_size_og); + } else { + out = at::empty_like(q); + } + out_padded = out.unsqueeze(0).transpose(1, 2); + + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size = head_size_og; + + auto opts = q.options(); + + auto softmax_lse = at::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat)); + at::Tensor M = softmax_lse.view({batch_size * num_heads, max_seqlen_q}); + at::Tensor softmax_fa_t; + // Only return softmax if there's dropout to reduce compilation time + if (return_softmax) { + TORCH_CHECK(p_dropout > 0.0f, "return_softmax is only supported when p_dropout > 0.0"); + softmax_fa_t = at::empty({ batch_size, num_heads, max_seqlen_q, max_seqlen_k }, opts); + } else { + softmax_fa_t = at::empty({ 0, 0, 0, 0 }, opts); + } + + if (zero_tensors) { + out.zero_(); + softmax_lse.fill_(-std::numeric_limits::infinity()); + if (return_softmax) { + softmax_fa_t.zero_(); + } + } + + auto [seed_t, offset_t, philox_state, use_philox_state] = + prepare_philox_arguments(p_dropout, batch_size * num_heads * 32); + + if (max_seqlen_k > 0) { + hipError_t err; // TODO: Error handling + using aotriton::v2::flash::attn_fwd_compact_varlen; + using sdp::aotriton_adapter::mk_aotensor; + using sdp::aotriton_adapter::mk_aoscalartensor; + using sdp::aotriton_adapter::mk_philoxtensor; + using sdp::aotriton_adapter::cast_dtype; + at::Tensor atomic_counter; + if (is_causal) { + atomic_counter = at::zeros({1}, q.options()); + } + aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype())); + auto seed = use_philox_state ? mk_philoxtensor(philox_state.seed_.ptr) : mk_aoscalartensor(seed_t); + auto offset1 = use_philox_state ? mk_philoxtensor(philox_state.offset_.ptr) : mk_aoscalartensor(offset_t); + auto offset2 = use_philox_state ? philox_state.offset_intragraph_ : 0; + auto nullscalar = mk_philoxtensor(nullptr); + auto seed_output = use_philox_state ? 
mk_philoxtensor(seed_t.data_ptr()) : nullscalar; + auto offset_output = use_philox_state ? mk_philoxtensor(offset_t.data_ptr()) : nullscalar; + auto persistent_counter = is_causal ? mk_philoxtensor(atomic_counter.data_ptr()) : nullscalar; + err = attn_fwd_compact_varlen(mk_aotensor(q_padded, "q"), + mk_aotensor(k_padded, "k"), + mk_aotensor(v_padded, "v"), + empty_bias, + mk_aotensor<1>(cu_seqlens_q, "cu_seqlens_q"), + mk_aotensor<1>(cu_seqlens_k, "cu_seqlens_k"), + max_seqlen_q, + max_seqlen_k, + softmax_scale, + mk_aotensor<2>(M, "M"), + mk_aotensor(out_padded, "Out"), + p_dropout, + seed, + offset1, + offset2, + seed_output, + offset_output, + mk_aotensor(softmax_fa_t, "encoded_softmax"), + is_causal, + persistent_counter, + stream); + } else { + // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0. + out.zero_(); + softmax_lse.fill_(std::numeric_limits::infinity()); + } + + return {out, q, k, v, softmax_lse, seed_t, offset_t, softmax_fa_t}; +} + +std::tuple +mha_bwd_aot(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_size_og + const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &softmax_lse, // b x h x seqlen_q + std::optional &dq_, // batch_size x seqlen_q x num_heads x head_size + std::optional &dk_, // batch_size x seqlen_k x num_heads_k x head_size + std::optional &dv_, // batch_size x seqlen_k x num_heads_k x head_size + std::optional &alibi_slopes_, // num_heads or batch_size x num_heads + const float p_dropout, // probability to drop + const float softmax_scale, + const bool is_causal, + int window_size_left, + int window_size_right, + const bool deterministic, + const at::Tensor& philox_seed, + const at::Tensor& philox_offset) { + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()}; + auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream(); + check_gpu_arch(stream); + + bool is_dropout = p_dropout > 0.0; + + auto q_dtype = q.dtype(); + TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16, + "FlashAttention only support fp16 and bf16 data type"); + TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); + TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); + TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype"); + TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype"); + + CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); + CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse); + + TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension"); + TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension"); + + const auto sizes = q.sizes(); + + const int batch_size = sizes[0]; + const int seqlen_q = sizes[1]; + const int num_heads = sizes[2]; + const int head_size_og = dout.size(3); + const int head_size = 
sizes[3]; + const int seqlen_k = k.size(1); + const int num_heads_k = k.size(2); + + if (is_causal){ + TORCH_CHECK((seqlen_q == seqlen_k), "For backwards kernel seqlen_q must equal seqlen_k for causal kernels"); + } + + TORCH_CHECK(batch_size > 0, "batch size must be positive"); + TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8"); + TORCH_CHECK(head_size_og % 8 == 0, "head_size_og should be a multiple of 8, this is ensured by padding!"); + TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256"); + TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size_rounded = round_multiple(head_size, 32); + + TORCH_CHECK(head_size == round_multiple(head_size_og, 8), "head_size must be head_size_og rounded to a multiple of 8"); + + CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size); + CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size); + CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size); + CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size); + CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size_og); + + at::Tensor dq, dk, dv; + if (dq_.has_value()) { + dq = dq_.value(); + TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q"); + CHECK_DEVICE(dq); + TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension"); + CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size); + } else { + dq = at::empty_like(q); + } + if (dk_.has_value()) { + dk = dk_.value(); + TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q"); + CHECK_DEVICE(dk); + TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension"); + CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size); + } else { + dk = at::empty_like(k); + } + if (dv_.has_value()) { + dv = dv_.value(); + TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q"); + CHECK_DEVICE(dv); + TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension"); + CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size); + } else { + dv = at::empty_like(k); + } + + auto opts = q.options(); + auto softmax_d = at::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); + + at::Tensor q_t = q.permute({0,2,1,3}); + at::Tensor k_t = k.permute({0,2,1,3}); + at::Tensor v_t = v.permute({0,2,1,3}); + at::Tensor out_t = out.permute({0,2,1,3}); + at::Tensor dq_t = dq.permute({0,2,1,3}); + at::Tensor dk_t = dk.permute({0,2,1,3}); + at::Tensor dv_t = dv.permute({0,2,1,3}); + at::Tensor dout_t = dout.permute({0,2,1,3}); + + at::Tensor softmax_lse_cont = softmax_lse.view({batch_size * num_heads, seqlen_q}).contiguous(); + + int d_head = head_size_og; + bool use_fused_bwd = d_head <= 192 && d_head * seqlen_q < 64 * 512; + hipError_t err; // TODO: Error handling + if (use_fused_bwd) { + using aotriton::v2::flash::attn_bwd_fused; + using sdp::aotriton_adapter::mk_aotensor; + using sdp::aotriton_adapter::mk_aoscalartensor; + using sdp::aotriton_adapter::cast_dtype; + aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype())); + err = attn_bwd_fused(mk_aotensor(q_t, "q"), + mk_aotensor(k_t, "k"), + mk_aotensor(v_t, "v"), + empty_bias, + softmax_scale, + mk_aotensor(out_t, "out"), + mk_aotensor(dout_t, "dout"), + mk_aotensor(dq_t, "dq"), + mk_aotensor(dk_t, "dk"), + mk_aotensor(dv_t, "dv"), + empty_bias, // dbb + 
mk_aotensor<2>(softmax_lse_cont, "L"), + p_dropout, + mk_aoscalartensor(philox_seed), + mk_aoscalartensor(philox_offset), + 0, + is_causal, + stream); + } else { + at::Tensor delta = at::empty_like(softmax_lse_cont).contiguous(); + using aotriton::v2::flash::attn_bwd; + using sdp::aotriton_adapter::mk_aotensor; + using sdp::aotriton_adapter::mk_aoscalartensor; + using sdp::aotriton_adapter::cast_dtype; + aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype())); + err = attn_bwd(mk_aotensor(q_t, "q"), + mk_aotensor(k_t, "k"), + mk_aotensor(v_t, "v"), + empty_bias, + softmax_scale, + mk_aotensor(out_t, "out"), + mk_aotensor(dout_t, "dout"), + mk_aotensor(dq_t, "dq"), + mk_aotensor(dk_t, "dk"), + mk_aotensor(dv_t, "dv"), + empty_bias, // db + mk_aotensor<2>(softmax_lse_cont, "L"), + mk_aotensor<2>(delta, "delta"), + p_dropout, + mk_aoscalartensor(philox_seed), + mk_aoscalartensor(philox_offset), + 0, + is_causal, + stream); + } + + return { dq, dk, dv, softmax_d }; +} + +std::tuple +mha_varlen_bwd_aot(const at::Tensor &dout, // total_q x num_heads, x head_size + const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &out, // total_q x num_heads x head_size + const at::Tensor &softmax_lse, // b x h x s softmax logsumexp + std::optional &dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + std::optional &dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + std::optional &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &cu_seqlens_q, // b+1 + const at::Tensor &cu_seqlens_k, // b+1 + std::optional &alibi_slopes_, // num_heads or b x num_heads + const int max_seqlen_q, + const int max_seqlen_k, // max sequence length to choose the kernel + const float p_dropout, // probability to drop + const float softmax_scale, + const bool zero_tensors, + const bool is_causal, + int window_size_left, + int window_size_right, + const bool deterministic, + const at::Tensor& philox_seed, + const at::Tensor& philox_offset) +{ + TORCH_CHECK(!alibi_slopes_.has_value(), "[ROCm] mha_varlen_fwd: alibi_slopes_ must be nullopt"); + + if (is_causal) { + window_size_right = 0; + } + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()}; + auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream(); + check_gpu_arch(stream); + + bool is_dropout = p_dropout > 0.0; + + auto q_dtype = q.dtype(); + TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16, + "FlashAttention only support fp16 and bf16 data type"); + TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); + TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); + TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype"); + TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype"); + TORCH_CHECK(cu_seqlens_q.dtype() == at::kInt, "cu_seqlens_q must have dtype int32"); + TORCH_CHECK(cu_seqlens_k.dtype() == at::kInt, "cu_seqlens_k must have dtype int32"); + + CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); + CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse); + 
CHECK_DEVICE(cu_seqlens_q); CHECK_DEVICE(cu_seqlens_k); + + TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension"); + TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension"); + CHECK_CONTIGUOUS(cu_seqlens_q); + CHECK_CONTIGUOUS(cu_seqlens_k); + + const auto sizes = q.sizes(); + + const int total_q = sizes[0]; + const int batch_size = cu_seqlens_q.numel() - 1; + const int num_heads = sizes[1]; + const int head_size_og = dout.size(2); + const int head_size = sizes[2]; + const int total_k = k.size(0); + const int num_heads_k = k.size(1); + TORCH_CHECK(batch_size > 0, "batch size must be positive"); + TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256"); + TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + + if (window_size_left >= max_seqlen_k) { window_size_left = -1; } + if (window_size_right >= max_seqlen_k) { window_size_right = -1; } + + CHECK_SHAPE(q, total_q, num_heads, head_size); + CHECK_SHAPE(k, total_k, num_heads_k, head_size); + CHECK_SHAPE(v, total_k, num_heads_k, head_size); + CHECK_SHAPE(out, total_q, num_heads, head_size); + CHECK_SHAPE(dout, total_q, num_heads, head_size_og); + CHECK_SHAPE(cu_seqlens_q, batch_size + 1); + CHECK_SHAPE(cu_seqlens_k, batch_size + 1); + + at::Tensor softmax_lse_cont = softmax_lse.view({batch_size * num_heads, max_seqlen_q}).contiguous(); + at::Tensor delta = at::empty_like(softmax_lse_cont).contiguous(); + + at::Tensor q_padded, k_padded, v_padded; + q_padded = q.unsqueeze(0).transpose(1, 2); + k_padded = k.unsqueeze(0).transpose(1, 2); + v_padded = v.unsqueeze(0).transpose(1, 2); + at::Tensor out_t, dout_t; + out_t = out.unsqueeze(0).transpose(1, 2); + dout_t = dout.unsqueeze(0).transpose(1, 2); + + at::Tensor dq, dk, dv; + at::Tensor dq_padded, dk_padded, dv_padded; + if (dq_.has_value()) { + dq = dq_.value(); + TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q"); + CHECK_DEVICE(dq); + TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension"); + CHECK_SHAPE(dq, total_q, num_heads, head_size); + } else { + dq = at::empty_like(q); + } + if (dk_.has_value()) { + dk = dk_.value(); + TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q"); + CHECK_DEVICE(dk); + TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension"); + CHECK_SHAPE(dk, total_k, num_heads_k, head_size); + } else { + dk = at::empty_like(k); + } + if (dv_.has_value()) { + dv = dv_.value(); + TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q"); + CHECK_DEVICE(dv); + TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension"); + CHECK_SHAPE(dv, total_k, num_heads_k, head_size); + } else { + dv = at::empty_like(v); + } + dq_padded = dq.unsqueeze(0).transpose(1, 2); + dk_padded = dk.unsqueeze(0).transpose(1, 2); + dv_padded = dv.unsqueeze(0).transpose(1, 2); + + auto opts = q.options(); + auto softmax_d = at::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat)); + + if( zero_tensors ) { + dq.zero_(); + dk.zero_(); + dv.zero_(); + softmax_d.zero_(); + } + + at::PhiloxCudaState philox_args; + if (is_dropout) { + if (at::cuda::currentStreamCaptureStatus() == + 
at::cuda::CaptureStatus::None) + { + philox_args = at::PhiloxCudaState(*philox_seed.data_ptr(), *philox_offset.data_ptr()); + } else { // dropout + capture + philox_args = at::PhiloxCudaState( + philox_seed.data_ptr(), philox_offset.data_ptr(), 0); + } + } + if (max_seqlen_q > 0) { + hipError_t err; // TODO: Error handling + using aotriton::v2::flash::attn_bwd_compact_varlen; + using sdp::aotriton_adapter::mk_aotensor; + using sdp::aotriton_adapter::mk_aoscalartensor; + using sdp::aotriton_adapter::cast_dtype; + aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype())); + err = attn_bwd_compact_varlen(mk_aotensor(q_padded, "q"), + mk_aotensor(k_padded, "k"), + mk_aotensor(v_padded, "v"), + mk_aotensor<1>(cu_seqlens_q, "cu_seqlens_q"), + mk_aotensor<1>(cu_seqlens_k, "cu_seqlens_k"), + max_seqlen_q, + max_seqlen_k, + empty_bias, + softmax_scale, + mk_aotensor(out_t, "out"), + mk_aotensor(dout_t, "dout"), + mk_aotensor(dq_padded, "dq"), + mk_aotensor(dk_padded, "dk"), + mk_aotensor(dv_padded, "dv"), + empty_bias, + mk_aotensor<2>(softmax_lse_cont, "L"), + mk_aotensor<2>(delta, "delta"), + p_dropout, + mk_aoscalartensor(philox_seed), + mk_aoscalartensor(philox_offset), + 0, + is_causal, + stream); + } else { + // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0. + dq.zero_(); + dk.zero_(); + dv.zero_(); + softmax_d.zero_(); + } + + return { dq, dk, dv, softmax_d }; +} +} // namespace pytorch_flash + +#endif diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h b/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h new file mode 100644 index 0000000000000..4daaa66e8a1a2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h @@ -0,0 +1,655 @@ +#pragma once +#include + +#include +#include +#include + +namespace pytorch_flash { + +// AOTriton Implementation +TORCH_API +std::tuple< + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor> +mha_fwd_aot( + const at::Tensor& q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor& k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor& v, // batch_size x seqlen_k x num_heads_k x head_size + std::optional& + out_, // batch_size x seqlen_q x num_heads x head_size + std::optional& + alibi_slopes_, // num_heads or batch_size x num_heads + const float p_dropout, + const float softmax_scale, + bool is_causal, + int window_size_left, + int window_size_right, + const bool return_softmax, + const std::optional& gen_); + +std::tuple< + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor> +mha_varlen_fwd_aot( + const at::Tensor& + q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor& + k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor& + v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + std::optional& + out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor& cu_seqlens_q, // b+1 + const at::Tensor& cu_seqlens_k, // b+1 + std::optional& + seqused_k, // b. If given, only this many elements of each batch + // element's keys are used. 
+ std::optional& block_table_, + std::optional& alibi_slopes_, // num_heads or b x num_heads + int max_seqlen_q, + const int max_seqlen_k, + const float p_dropout, + const float softmax_scale, + const bool zero_tensors, + bool is_causal, + int window_size_left, + int window_size_right, + const bool return_softmax, + const std::optional& gen_); + +std::tuple mha_bwd_aot( + const at::Tensor& dout, // batch_size x seqlen_q x num_heads, x head_size_og + const at::Tensor& q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor& k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor& v, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor& out, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor& softmax_lse, // b x h x seqlen_q + std::optional& + dq_, // batch_size x seqlen_q x num_heads x head_size + std::optional& + dk_, // batch_size x seqlen_k x num_heads_k x head_size + std::optional& + dv_, // batch_size x seqlen_k x num_heads_k x head_size + std::optional& + alibi_slopes_, // num_heads or batch_size x num_heads + const float p_dropout, // probability to drop + const float softmax_scale, + const bool is_causal, + int window_size_left, + int window_size_right, + const bool deterministic, + const at::Tensor& philox_seed, + const at::Tensor& philox_offset); + +std::tuple mha_varlen_bwd_aot( + const at::Tensor& dout, // total_q x num_heads, x head_size + const at::Tensor& + q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor& + k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor& + v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor& out, // total_q x num_heads x head_size + const at::Tensor& softmax_lse, // b x h x s softmax logsumexp + std::optional& + dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + std::optional& + dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + std::optional& + dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor& cu_seqlens_q, // b+1 + const at::Tensor& cu_seqlens_k, // b+1 + std::optional& alibi_slopes_, // num_heads or b x num_heads + const int max_seqlen_q, + const int max_seqlen_k, // max sequence length to choose the kernel + const float p_dropout, // probability to drop + const float softmax_scale, + const bool zero_tensors, + const bool is_causal, + int window_size_left, + int window_size_right, + const bool deterministic, + const at::Tensor& philox_seed, + const at::Tensor& philox_offset); + +#if defined(USE_CK_FLASH_ATTENTION) +// CK implementation +TORCH_API +std::tuple< + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor> +mha_fwd_ck( + const at::Tensor& q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor& k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor& v, // batch_size x seqlen_k x num_heads_k x head_size + std::optional& + out_, // batch_size x seqlen_q x num_heads x head_size + const float p_dropout, + const float softmax_scale, + bool is_causal, + int window_size_left, + int window_size_right, + const bool return_softmax, + std::optional gen_, + const std::optional& attn_bias_); // batch_size x nheads x seqlen_q x seqlen_k + +std::tuple< + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor> +mha_varlen_fwd_ck( + const at::Tensor& + 
q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor& + k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor& + v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + std::optional& + out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor& cu_seqlens_q, // b+1 + const at::Tensor& cu_seqlens_k, // b+1 + std::optional& + seqused_k, // b. If given, only this many elements of each batch + // element's keys are used. + int max_seqlen_q, + const int max_seqlen_k, + const float p_dropout, + const float softmax_scale, + const bool zero_tensors, + bool is_causal, + int window_size_left, + int window_size_right, + const bool return_softmax, + std::optional gen_, + const std::optional& attn_bias_); + +std::tuple mha_bwd_ck( + const at::Tensor& dout, // batch_size x seqlen_q x num_heads, x head_size_og + const at::Tensor& q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor& k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor& v, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor& out, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor& softmax_lse, // b x h x seqlen_q + std::optional& + dq_, // batch_size x seqlen_q x num_heads x head_size + std::optional& + dk_, // batch_size x seqlen_k x num_heads_k x head_size + std::optional& + dv_, // batch_size x seqlen_k x num_heads_k x head_size + std::optional& + attn_bias_, // batch_size x num_heads x seqlen_q x seqlen_k + bool bias_requires_grad, + std::optional& grad_bias, + const float p_dropout, // probability to drop + const float softmax_scale, + const bool is_causal, + int window_size_left, + int window_size_right, + const bool deterministic, + const at::Tensor philox_seed, + const at::Tensor philox_offset); + +std::tuple mha_varlen_bwd_ck( + const at::Tensor& dout, // total_q x num_heads, x head_size + const at::Tensor& + q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor& + k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor& + v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor& out, // total_q x num_heads x head_size + const at::Tensor& softmax_lse, // b x h x s softmax logsumexp + std::optional& + dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + std::optional& + dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + std::optional& + dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor& cu_seqlens_q, // b+1 + const at::Tensor& cu_seqlens_k, // b+1 + std::optional& attn_bias_, // num_heads or b x num_heads + bool bias_requires_grad, + std::optional& grad_bias, + const int max_seqlen_q, + const int max_seqlen_k, // max sequence length to choose the kernel + const float p_dropout, // probability to drop + const float softmax_scale, + const bool zero_tensors, + const bool is_causal, + int window_size_left, + int window_size_right, + const bool deterministic, + const at::Tensor philox_seed, + const at::Tensor philox_offset); +#endif + +TORCH_API +inline std::tuple< + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor> +mha_fwd( + const at::Tensor& q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor& k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor& v, // 
batch_size x seqlen_k x num_heads_k x head_size + std::optional& + out_, // batch_size x seqlen_q x num_heads x head_size + std::optional& + alibi_slopes_, // num_heads or batch_size x num_heads + const float p_dropout, + const float softmax_scale, + bool is_causal, + int window_size_left, + int window_size_right, + const float softcap, + const bool return_softmax, + std::optional gen_) { +#if defined(USE_CK_FLASH_ATTENTION) + if (at::globalContext().getROCmFAPreferredBackend() == + at::ROCmFABackend::Ck) { + std::optional dummy_attn_bias = std::nullopt; + return mha_fwd_ck( + q, + k, + v, + out_, + p_dropout, + softmax_scale, + is_causal, + window_size_left, + window_size_right, + return_softmax, + gen_, + dummy_attn_bias); // Not used in flash attention + } else { + return mha_fwd_aot( + q, + k, + v, + out_, + alibi_slopes_, + p_dropout, + softmax_scale, + is_causal, + window_size_left, + window_size_right, + return_softmax, + gen_); + } +#else + return mha_fwd_aot( + q, + k, + v, + out_, + alibi_slopes_, + p_dropout, + softmax_scale, + is_causal, + window_size_left, + window_size_right, + return_softmax, + gen_); +#endif +} + +inline std::tuple< + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor> +mha_varlen_fwd( + const at::Tensor& + q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor& + k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor& + v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + std::optional& + out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor& cu_seqlens_q, // b+1 + const at::Tensor& cu_seqlens_k, // b+1 + std::optional& + seqused_k, // b. If given, only this many elements of each batch + // element's keys are used. + std::optional& + block_table_, // Not used on ROCm. 
Keeping for parity with CUDA + std::optional& alibi_slopes_, // num_heads or b x num_heads + int max_seqlen_q, + const int max_seqlen_k, + const float p_dropout, + const float softmax_scale, + const bool zero_tensors, + bool is_causal, + int window_size_left, + int window_size_right, + const float softcap, + const bool return_softmax, + std::optional gen_) { +#if defined(USE_CK_FLASH_ATTENTION) + if (at::globalContext().getROCmFAPreferredBackend() == + at::ROCmFABackend::Ck) { + std::optional dummy_attn_bias = std::nullopt; + return mha_varlen_fwd_ck( + q, + k, + v, + out_, + cu_seqlens_q, + cu_seqlens_k, + seqused_k, + max_seqlen_q, + max_seqlen_k, + p_dropout, + softmax_scale, + zero_tensors, + is_causal, + window_size_left, + window_size_right, + return_softmax, + gen_, + dummy_attn_bias); // Not used in flash attention + } else { + return mha_varlen_fwd_aot( + q, + k, + v, + out_, + cu_seqlens_q, + cu_seqlens_k, + seqused_k, + block_table_, + alibi_slopes_, + max_seqlen_q, + max_seqlen_k, + p_dropout, + softmax_scale, + zero_tensors, + is_causal, + window_size_left, + window_size_right, + return_softmax, + gen_); + } +#else + return mha_varlen_fwd_aot( + q, + k, + v, + out_, + cu_seqlens_q, + cu_seqlens_k, + seqused_k, + block_table_, + alibi_slopes_, + max_seqlen_q, + max_seqlen_k, + p_dropout, + softmax_scale, + zero_tensors, + is_causal, + window_size_left, + window_size_right, + return_softmax, + gen_); +#endif +} + +inline std::tuple mha_bwd( + const at::Tensor& dout, // batch_size x seqlen_q x num_heads, x head_size_og + const at::Tensor& q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor& k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor& v, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor& out, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor& softmax_lse, // b x h x seqlen_q + std::optional& + dq_, // batch_size x seqlen_q x num_heads x head_size + std::optional& + dk_, // batch_size x seqlen_k x num_heads_k x head_size + std::optional& + dv_, // batch_size x seqlen_k x num_heads_k x head_size + std::optional& + alibi_slopes_, // num_heads or batch_size x num_heads + const float p_dropout, // probability to drop + const float softmax_scale, + const bool is_causal, + int window_size_left, + int window_size_right, + const float softcap, + const bool deterministic, + const at::Tensor philox_seed, + const at::Tensor philox_offset) { +#if defined(USE_CK_FLASH_ATTENTION) + if (at::globalContext().getROCmFAPreferredBackend() == + at::ROCmFABackend::Ck) { + std::optional non_null_dbias = std::nullopt; + auto[dQuery, + dKey, + dValue, + dSoftmax, + dBias] = mha_bwd_ck( + dout, + q, + k, + v, + out, + softmax_lse, + dq_, + dk_, + dv_, + alibi_slopes_, + false, // bias_requires_grad + non_null_dbias, + p_dropout, + softmax_scale, + is_causal, + window_size_left, + window_size_right, + deterministic, + philox_seed, + philox_offset); + // for FA return [dQ, dV, dK, dSoftmax] + return std::make_tuple(std::move(dQuery), std::move(dKey), std::move(dValue), std::move(dSoftmax)); + } else { + return mha_bwd_aot( + dout, + q, + k, + v, + out, + softmax_lse, + dq_, + dk_, + dv_, + alibi_slopes_, + p_dropout, + softmax_scale, + is_causal, + window_size_left, + window_size_right, + deterministic, + philox_seed, + philox_offset); + } +#else + if(at::globalContext().getROCmFAPreferredBackend() == + at::ROCmFABackend::Ck) { + TORCH_WARN_ONCE("Warning! 
You have opted to use CK flash attention backend in a build that was not compiled using USE_CK_FLASH_ATTENTION=1. Please set this variable and try again. Defaulting to use aotriton backend..."); + } + return mha_bwd_aot( + dout, + q, + k, + v, + out, + softmax_lse, + dq_, + dk_, + dv_, + alibi_slopes_, + p_dropout, + softmax_scale, + is_causal, + window_size_left, + window_size_right, + deterministic, + philox_seed, + philox_offset); +#endif +} + +inline std::tuple mha_varlen_bwd( + const at::Tensor& dout, // total_q x num_heads, x head_size + const at::Tensor& + q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor& + k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor& + v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor& out, // total_q x num_heads x head_size + const at::Tensor& softmax_lse, // b x h x s softmax logsumexp + std::optional& + dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + std::optional& + dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + std::optional& + dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor& cu_seqlens_q, // b+1 + const at::Tensor& cu_seqlens_k, // b+1 + std::optional& alibi_slopes_, // num_heads or b x num_heads + const int max_seqlen_q, + const int max_seqlen_k, // max sequence length to choose the kernel + const float p_dropout, // probability to drop + const float softmax_scale, + const bool zero_tensors, + const bool is_causal, + int window_size_left, + int window_size_right, + const float softcap, + const bool deterministic, + const at::Tensor philox_seed, + const at::Tensor philox_offset) { +#if defined(USE_CK_FLASH_ATTENTION) + if (at::globalContext().getROCmFAPreferredBackend() == + at::ROCmFABackend::Ck) { + std::optional non_null_dbias = std::nullopt; + auto[dQuery, + dKey, + dValue, + dSoftmax, + dBias] = mha_varlen_bwd_ck( + dout, + q, + k, + v, + out, + softmax_lse, + dq_, + dk_, + dv_, + cu_seqlens_q, + cu_seqlens_k, + alibi_slopes_, + false, // bias_requires_grad + non_null_dbias, + max_seqlen_q, + max_seqlen_k, + p_dropout, + softmax_scale, + zero_tensors, + is_causal, + window_size_left, + window_size_right, + deterministic, + philox_seed, + philox_offset); + // for FA return [dQ, dV, dK, dSoftmax] + return std::make_tuple(std::move(dQuery), std::move(dKey), std::move(dValue), std::move(dSoftmax)); + } else { + return mha_varlen_bwd_aot( + dout, + q, + k, + v, + out, + softmax_lse, + dq_, + dk_, + dv_, + cu_seqlens_q, + cu_seqlens_k, + alibi_slopes_, + max_seqlen_q, + max_seqlen_k, + p_dropout, + softmax_scale, + zero_tensors, + is_causal, + window_size_left, + window_size_right, + deterministic, + philox_seed, + philox_offset); + } +#else + return mha_varlen_bwd_aot( + dout, + q, + k, + v, + out, + softmax_lse, + dq_, + dk_, + dv_, + cu_seqlens_q, + cu_seqlens_k, + alibi_slopes_, + max_seqlen_q, + max_seqlen_k, + p_dropout, + softmax_scale, + zero_tensors, + is_causal, + window_size_left, + window_size_right, + deterministic, + philox_seed, + philox_offset); +#endif +} + +} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.hip b/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.hip deleted file mode 100644 index 9b0820a501bf4..0000000000000 --- a/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.hip +++ /dev/null @@ -1,504 +0,0 @@ 
-/****************************************************************************** - * Copyright (c) 2023, Advanced Micro Devices, Inc. - * Copyright (c) 2022, Tri Dao. - * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the - * names of its contributors may be used to endorse or promote products - * derived from this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND - * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED - * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY - * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES - * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; - * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND - * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - ******************************************************************************/ -#include -#define TORCH_ASSERT_ONLY_METHOD_OPERATORS - -#include -#include - -#include - -#ifdef USE_FLASH_ATTENTION -#include -#include -#include -#include - -#ifndef AT_PER_OPERATOR_HEADERS -#include -#include -#else -#include -#include -#include -#include -#include -#include -#include -#include -#endif - -#include -#include - -#include -#include - -// AOTriton headers -#include -#include - -namespace pytorch_flash { - -namespace { - -void check_gpu_arch(hipStream_t stream) { - auto ret = aotriton::v2::flash::check_gpu(stream); - if (hipSuccess != ret) { - TORCH_CHECK(false, - "[AOTriton] Accelerated SDPA only supports MI200/MI300X/Navi31 GPUs" - " (gfx90a:sramecc+:xnack-/gfx942:sramecc+:xnack-/gfx1100)") - } -} - -} - -#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA") -#define CHECK_SHAPE(x, ...) 
TORCH_CHECK(x.sizes() == at::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") -#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") - -std::tuple -mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size - const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size - const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size - c10::optional &out_, // batch_size x seqlen_q x num_heads x head_size - c10::optional &alibi_slopes_, // num_heads or batch_size x num_heads - const float p_dropout, - const float softmax_scale, - bool is_causal, - int window_size_left, - int window_size_right, - const bool return_softmax, - c10::optional gen_) { - auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream(); - check_gpu_arch(stream); - - auto q_dtype = q.dtype(); - TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16, - "FlashAttention only support fp16 and bf16 data type"); - TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); - TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); - - CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); - - // FIXME: ROCM probably does not need this - TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - - const auto sizes = q.sizes(); - - const int batch_size = sizes[0]; - int seqlen_q = sizes[1]; - int num_heads = sizes[2]; - const int head_size_og = sizes[3]; - const int seqlen_k = k.size(1); - const int num_heads_k = k.size(2); - TORCH_CHECK(batch_size > 0, "batch size must be positive"); - TORCH_CHECK(head_size_og % 8 == 0, "head_size must be a multiple of 8, this is ensured by padding!"); - TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256"); - TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); - - if (seqlen_q == 1) { is_causal = false; } // causal=true is the same as causal=false in this case - if (is_causal) { window_size_right = 0; } - - CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og); - CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_og); - CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_og); - - at::Tensor q_padded, k_padded, v_padded; - q_padded = q; - k_padded = k; - v_padded = v; - - at::Tensor out; - if (out_.has_value()) { - out = out_.value(); - TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs"); - CHECK_DEVICE(out); - TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); - CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og); - if (head_size_og % 8 != 0) { out = at::empty_like(q_padded); } - } else { - out = at::empty_like(q_padded); - } - - auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; - const int head_size = round_multiple(head_size_og, 8); - const int head_size_rounded = round_multiple(head_size, 32); - const int seqlen_q_rounded = round_multiple(seqlen_q, 128); - const int seqlen_k_rounded = round_multiple(seqlen_k, 128); - - // Otherwise the kernel will be launched from cuda:0 device - // Cast to char to avoid compiler warning about narrowing - at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()}; - - // We want to 
checkpoint and save the RNG state for backward if dropout - // We get the default generator and return the seed and offset which will - // be used in the backward function - auto gen = at::get_generator_or_default(c10::nullopt, at::cuda::detail::getDefaultCUDAGenerator()); - at::Tensor seed_t, offset_t; - - at::PhiloxCudaState philox_state; - bool use_philox_state = false; - if (p_dropout > 0.0) { - // number of times random will be generated per thread, to offset philox counter in thc random - // state - // We use a custom RNG that increases the offset by batch_size * nheads * 32. - int64_t counter_offset = batch_size * num_heads * 32; - // See Note [Acquire lock when using random generators] - std::lock_guard lock(gen->mutex_); - philox_state = gen->philox_cuda_state(counter_offset); - if (at::cuda::currentStreamCaptureStatus() == at::cuda::CaptureStatus::None) { - auto [seed, offset] = at::cuda::philox::unpack(philox_state); - seed_t = at::scalar_tensor(at::Scalar(static_cast(seed)), at::dtype(at::kLong).device(at::kCUDA)); - offset_t = at::scalar_tensor(at::Scalar(static_cast(offset)), at::dtype(at::kLong).device(at::kCUDA)); - } else { - // See Note [CUDA Graph-safe RNG states] about the design - use_philox_state = true; - seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); - offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); - } - } else { - if (at::cuda::currentStreamCaptureStatus() != at::cuda::CaptureStatus::None) { - seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); - offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); - } else { - seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); - offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); - } - } - - at::PhiloxCudaState philox_args; - if (p_dropout > 0.0) { - if (at::cuda::currentStreamCaptureStatus() == - at::cuda::CaptureStatus::None) - { - philox_args = at::PhiloxCudaState(*seed_t.data_ptr(), *offset_t.data_ptr()); - } else { // dropout + capture - philox_args = at::PhiloxCudaState(seed_t.data_ptr(), offset_t.data_ptr(), 0); - } - } - - // Transpose tensors to meet AOTriton's Flash API - at::Tensor q_t = q_padded.permute({0,2,1,3}); - at::Tensor k_t = k_padded.permute({0,2,1,3}); - at::Tensor v_t = v_padded.permute({0,2,1,3}); - at::Tensor output_t = out.permute({0,2,1,3}); - - at::Tensor M = at::empty({batch_size * num_heads, seqlen_q}, at::dtype(at::kFloat).device(q.device())); // aka softmax_lse - - at::Tensor softmax_fa_t; - if (return_softmax) { - softmax_fa_t = at::empty({batch_size, num_heads, seqlen_q, seqlen_k}, - at::dtype(q.dtype()).device(q.device())); - } else { - softmax_fa_t = at::empty({ 0, 0, 0, 0 }, at::dtype(q.dtype()).device(q.device())); - } - - hipError_t err; // TODO: Error handling - using aotriton::v2::flash::attn_fwd; - using aotriton::TensorView; - using sdp::aotriton_adapter::mk_aotensor; - using sdp::aotriton_adapter::mk_aoscalartensor; - using sdp::aotriton_adapter::mk_philoxtensor; - using sdp::aotriton_adapter::cast_dtype; - aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype())); - auto seed = use_philox_state ? mk_philoxtensor(philox_state.seed_.ptr) : mk_aoscalartensor(seed_t); - auto offset1 = use_philox_state ? mk_philoxtensor(philox_state.offset_.ptr) : mk_aoscalartensor(offset_t); - auto offset2 = use_philox_state ? philox_state.offset_intragraph_ : 0; - auto seed_output = use_philox_state ? 
mk_philoxtensor(seed_t.data_ptr()) : mk_philoxtensor(nullptr); - auto offset_output = use_philox_state ? mk_philoxtensor(offset_t.data_ptr()) : mk_philoxtensor(nullptr); - err = attn_fwd(mk_aotensor(q_t, "q"), - mk_aotensor(k_t, "k"), - mk_aotensor(v_t, "v"), - empty_bias, - softmax_scale, - mk_aotensor<2>(M, "M"), - mk_aotensor(output_t, "Out"), - p_dropout, - seed, - offset1, - offset2, - seed_output, - offset_output, - mk_aotensor(softmax_fa_t, "encoded_softmax"), - is_causal, - stream); - - return {out, q_padded, k_padded, v_padded, M, seed_t, offset_t, softmax_fa_t}; -} - -std::tuple -mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i - const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i - const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i - c10::optional &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i - const at::Tensor &cu_seqlens_q, // b+1 - const at::Tensor &cu_seqlens_k, // b+1 - c10::optional &seqused_k, // b. If given, only this many elements of each batch element's keys are used. - c10::optional &alibi_slopes_, // num_heads or b x num_heads - int max_seqlen_q, - const int max_seqlen_k, - const float p_dropout, - const float softmax_scale, - const bool zero_tensors, - bool is_causal, - int window_size_left, - int window_size_right, - const bool return_softmax, - c10::optional gen_) { - - TORCH_CHECK(false, "mha_varlen_fwd not supported on ROCm"); - - at::Tensor softmax_lse = at::empty({}, at::dtype(at::kFloat)); - at::Tensor p = at::empty({}, at::dtype(at::kFloat)); - at::Tensor offset_t = at::empty({}, at::dtype(at::kLong)); - at::Tensor seed_t = at::empty({}, at::dtype(at::kLong)); - at::Tensor out = at::empty({}, at::dtype(at::kFloat)); - - return {out, q, k, v, softmax_lse, seed_t, offset_t, p}; -} - -std::tuple -mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_size_og - const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size - const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size - const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size - const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size - const at::Tensor &softmax_lse, // b x h x seqlen_q - c10::optional &dq_, // batch_size x seqlen_q x num_heads x head_size - c10::optional &dk_, // batch_size x seqlen_k x num_heads_k x head_size - c10::optional &dv_, // batch_size x seqlen_k x num_heads_k x head_size - c10::optional &alibi_slopes_, // num_heads or batch_size x num_heads - const float p_dropout, // probability to drop - const float softmax_scale, - const bool is_causal, - int window_size_left, - int window_size_right, - const bool deterministic, - const at::Tensor philox_seed, - const at::Tensor philox_offset) { - auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream(); - check_gpu_arch(stream); - - bool is_dropout = p_dropout > 0.0; - - auto q_dtype = q.dtype(); - TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16, - "FlashAttention only support fp16 and bf16 data type"); - TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); - TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); - TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype"); - TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype"); - - CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); - 
CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse); - - TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension"); - TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension"); - - const auto sizes = q.sizes(); - - const int batch_size = sizes[0]; - const int seqlen_q = sizes[1]; - const int num_heads = sizes[2]; - const int head_size_og = dout.size(3); - const int head_size = sizes[3]; - const int seqlen_k = k.size(1); - const int num_heads_k = k.size(2); - - if (is_causal){ - TORCH_CHECK((seqlen_q == seqlen_k), "For backwards kernel seqlen_q must equal seqlen_k for causal kernels"); - } - - TORCH_CHECK(batch_size > 0, "batch size must be positive"); - TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8"); - TORCH_CHECK(head_size_og % 8 == 0, "head_size_og should be a multiple of 8, this is ensured by padding!"); - TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256"); - TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); - - auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; - const int head_size_rounded = round_multiple(head_size, 32); - const int seqlen_q_rounded = round_multiple(seqlen_q, 128); - const int seqlen_k_rounded = round_multiple(seqlen_k, 128); - - TORCH_CHECK(head_size == round_multiple(head_size_og, 8), "head_size must be head_size_og rounded to a multiple of 8"); - - CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size); - CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size); - CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size); - CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size); - CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size_og); - - at::Tensor dq, dk, dv; - if (dq_.has_value()) { - dq = dq_.value(); - TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q"); - CHECK_DEVICE(dq); - TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension"); - CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size); - } else { - dq = at::empty_like(q); - } - if (dk_.has_value()) { - dk = dk_.value(); - TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q"); - CHECK_DEVICE(dk); - TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension"); - CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size); - } else { - dk = at::empty_like(k); - } - if (dv_.has_value()) { - dv = dv_.value(); - TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q"); - CHECK_DEVICE(dv); - TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension"); - CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size); - } else { - dv = at::empty_like(k); - } - - // const at::Tensor& dout_padded = dout; - - // Otherwise the kernel will be launched from cuda:0 device - // Cast to char to avoid compiler warning about narrowing - at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()}; - - auto opts = q.options(); - auto softmax_d = at::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat)); - - at::Tensor dk_expanded, dv_expanded; - if (num_heads_k != num_heads) { // MQA / GQA - dk_expanded 
= at::empty({batch_size, seqlen_k, num_heads, head_size}, opts); - dv_expanded = at::empty({batch_size, seqlen_k, num_heads, head_size}, opts); - } else { - dk_expanded = dk; - dv_expanded = dv; - } - - at::PhiloxCudaState philox_args; - if (p_dropout > 0.0) { - if (at::cuda::currentStreamCaptureStatus() == - at::cuda::CaptureStatus::None) - { - philox_args = at::PhiloxCudaState(*philox_seed.data_ptr(), *philox_offset.data_ptr()); - } else { // dropout + capture - philox_args = at::PhiloxCudaState(philox_seed.data_ptr(), philox_offset.data_ptr(), 0); - } - } - - at::Tensor q_t = q.permute({0,2,1,3}); - at::Tensor k_t = k.permute({0,2,1,3}); - at::Tensor v_t = v.permute({0,2,1,3}); - at::Tensor out_t = out.permute({0,2,1,3}); - at::Tensor dq_t = dq.permute({0,2,1,3}); - at::Tensor dk_t = dk.permute({0,2,1,3}); - at::Tensor dv_t = dv.permute({0,2,1,3}); - at::Tensor dout_t = dout.permute({0,2,1,3}); - - at::Tensor softmax_lse_cont = softmax_lse.contiguous(); - at::Tensor delta = at::empty_like(softmax_lse).contiguous(); - - int d_head = head_size_og; - hipError_t err; // TODO: Error handling - { - using aotriton::v2::flash::attn_bwd; - using sdp::aotriton_adapter::mk_aotensor; - using sdp::aotriton_adapter::mk_aoscalartensor; - using sdp::aotriton_adapter::cast_dtype; - aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype())); - err = attn_bwd(mk_aotensor(q_t, "q"), - mk_aotensor(k_t, "k"), - mk_aotensor(v_t, "v"), - empty_bias, - softmax_scale, - mk_aotensor(out_t, "out"), - mk_aotensor(dout_t, "dout"), - mk_aotensor(dq_t, "dq"), - mk_aotensor(dk_t, "dk"), - mk_aotensor(dv_t, "dv"), - empty_bias, - mk_aotensor<2>(softmax_lse_cont, "L"), - mk_aotensor<2>(delta, "delta"), - p_dropout, - mk_aoscalartensor(philox_seed), - mk_aoscalartensor(philox_offset), - 0, - is_causal, - stream); - } - - // For MQA/GQA we need to sum dK and dV across the groups - if (num_heads_k != num_heads) { - at::sum_out(dk, at::reshape(dk_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3}); - at::sum_out(dv, at::reshape(dv_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3}); - } - return { dq, dk, dv, softmax_d }; -#undef CALL_BWD_DROPOUT -#undef CALL_BWD -} - -std::tuple -mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size - const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i - const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i - const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i - const at::Tensor &out, // total_q x num_heads x head_size - const at::Tensor &softmax_lse, // b x h x s softmax logsumexp - c10::optional &dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i - c10::optional &dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i - c10::optional &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i - const at::Tensor &cu_seqlens_q, // b+1 - const at::Tensor &cu_seqlens_k, // b+1 - c10::optional &alibi_slopes_, // num_heads or b x num_heads - const int max_seqlen_q, - const int max_seqlen_k, // max sequence length to choose the kernel - const float p_dropout, // probability to drop - const float softmax_scale, - const bool zero_tensors, - const bool is_causal, - int window_size_left, - int window_size_right, - const bool deterministic, - const at::Tensor philox_seed, - const at::Tensor philox_offset) { - TORCH_CHECK(false, 
"mha_varlen_bwd not supported on ROCm"); - - at::Tensor softmax_d = at::empty({}, at::dtype(at::kFloat)); - - return { q, k, v, softmax_d }; -} -} // namespace pytorch_fmha - -#endif diff --git a/cmake/External/aotriton.cmake b/cmake/External/aotriton.cmake index aacdbe2e955e7..fd165d8330a2f 100644 --- a/cmake/External/aotriton.cmake +++ b/cmake/External/aotriton.cmake @@ -22,19 +22,22 @@ if(NOT __AOTRITON_INCLUDED) # Replaces .ci/docker/aotriton_version.txt # Note packages information may have versions skipped (due to no ABI breaks) # But they must be listed from lower version to higher version - set(__AOTRITON_VER "0.8.2b") + set(__AOTRITON_VER "0.9.2b") set(__AOTRITON_MANYLINUX_LIST "manylinux_2_28" # rocm6.2 "manylinux_2_28" # rocm6.3 + "manylinux_2_28" # rocm6.4 ) set(__AOTRITON_ROCM_LIST "rocm6.2" "rocm6.3" + "rocm6.4" ) - set(__AOTRITON_CI_COMMIT "b24f43a9771622faa157155568b9a200c3b49e41") + set(__AOTRITON_CI_COMMIT "b388d223d8c7213545603e00f6f3148c54d1f525") set(__AOTRITON_SHA256_LIST - "66445e6b0209b9f4080743b839cc9d424054dc5c8d07363f9f27f109231c324a" # rocm6.2 - "16356dc1813cf3e60da23eb2f29440cb35eedd3a2fbf81e6093a0bc42056ad08" # rocm6.3 + "08d84f96f4c984179f80f517c0431c7511ee26bb0ce9bd05a827573ddd78cc79" # rocm6.2 + "9094d59717e7e6eace9126ca100dd0e86510f07fc6c3a349569fc4e2d9056604" # rocm6.3 + "41190202c2736d5ff75b13a3abc0fb52ebfbb67226cf85dc3de7699c7000db44" # rocm6.4 ) set(__AOTRITON_Z "gz") From cef2cbc0bee2193707315c65075e3db0e78fe18f Mon Sep 17 00:00:00 2001 From: Xinya Zhang Date: Thu, 8 May 2025 14:31:58 -0500 Subject: [PATCH 05/27] AOTriton: add 0.9.2b version built on ROCM 6.5, with gfx950 supported. (#2105) Also fixes the URL problem, where release page does not always match the version string in file name. --- cmake/External/aotriton.cmake | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/cmake/External/aotriton.cmake b/cmake/External/aotriton.cmake index fd165d8330a2f..12a1d07662930 100644 --- a/cmake/External/aotriton.cmake +++ b/cmake/External/aotriton.cmake @@ -22,22 +22,31 @@ if(NOT __AOTRITON_INCLUDED) # Replaces .ci/docker/aotriton_version.txt # Note packages information may have versions skipped (due to no ABI breaks) # But they must be listed from lower version to higher version - set(__AOTRITON_VER "0.9.2b") + set(__AOTRITON_RELEASE_PAGE "0.9.2b") + set(__AOTRITON_VER_LIST + "0.9.2b" # rocm6.2 + "0.9.2b" # rocm6.3 + "0.9.2b" # rocm6.4 + "0.9.2b_612896439f" # rocm6.5 with gfx950 + ) set(__AOTRITON_MANYLINUX_LIST "manylinux_2_28" # rocm6.2 "manylinux_2_28" # rocm6.3 "manylinux_2_28" # rocm6.4 + "manylinux_2_28" # rocm6.5 ) set(__AOTRITON_ROCM_LIST "rocm6.2" "rocm6.3" "rocm6.4" + "rocm6.5" ) - set(__AOTRITON_CI_COMMIT "b388d223d8c7213545603e00f6f3148c54d1f525") + set(__AOTRITON_CI_COMMIT "612896439fb4f78509b1a566b5ef0a333e9585bb") # source of rocm6.5 with gfx950 set(__AOTRITON_SHA256_LIST "08d84f96f4c984179f80f517c0431c7511ee26bb0ce9bd05a827573ddd78cc79" # rocm6.2 "9094d59717e7e6eace9126ca100dd0e86510f07fc6c3a349569fc4e2d9056604" # rocm6.3 "41190202c2736d5ff75b13a3abc0fb52ebfbb67226cf85dc3de7699c7000db44" # rocm6.4 + "c85da64d21510190277794455ef8bd3f2d543a6f2462140d3da27e1df0ab8f82" # rocm6.5 with gfx950 ) set(__AOTRITON_Z "gz") @@ -67,13 +76,14 @@ if(NOT __AOTRITON_INCLUDED) list(FIND __AOTRITON_ROCM_LIST "rocm${__AOTRITON_ROCM}" __AOTRITON_ROCM_INDEX) list(GET __AOTRITON_SHA256_LIST ${__AOTRITON_ROCM_INDEX} __AOTRITON_SHA256) list(GET __AOTRITON_MANYLINUX_LIST ${__AOTRITON_ROCM_INDEX} __AOTRITON_MANYLINUX) + list(GET 
__AOTRITON_VER_LIST ${__AOTRITON_ROCM_INDEX} __AOTRITON_VER) set(__AOTRITON_ARCH ${CMAKE_HOST_SYSTEM_PROCESSOR}) string(CONCAT __AOTRITON_FILE "aotriton-" "${__AOTRITON_VER}-${__AOTRITON_MANYLINUX}" "_${__AOTRITON_ARCH}-rocm${__AOTRITON_ROCM}" "-shared.tar.${__AOTRITON_Z}") string(CONCAT __AOTRITON_URL "https://github.com/ROCm/aotriton/releases/download/" - "${__AOTRITON_VER}/${__AOTRITON_FILE}") + "${__AOTRITON_RELEASE_PAGE}/${__AOTRITON_FILE}") ExternalProject_Add(aotriton_external URL "${__AOTRITON_URL}" URL_HASH SHA256=${__AOTRITON_SHA256} From c584d44b1feab6358d04d7b3950c7a1985384661 Mon Sep 17 00:00:00 2001 From: Xinya Zhang Date: Tue, 1 Jul 2025 14:37:29 -0500 Subject: [PATCH 06/27] [release/2.7] [AOTriton] Support ROCM 7.0 ABI (#2302) Per request from SWDEV-540108 --- cmake/External/aotriton.cmake | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/cmake/External/aotriton.cmake b/cmake/External/aotriton.cmake index 12a1d07662930..df3102ff3587a 100644 --- a/cmake/External/aotriton.cmake +++ b/cmake/External/aotriton.cmake @@ -28,18 +28,21 @@ if(NOT __AOTRITON_INCLUDED) "0.9.2b" # rocm6.3 "0.9.2b" # rocm6.4 "0.9.2b_612896439f" # rocm6.5 with gfx950 + "0.9.2b_612896439f" # rocm7.0 ) set(__AOTRITON_MANYLINUX_LIST "manylinux_2_28" # rocm6.2 "manylinux_2_28" # rocm6.3 "manylinux_2_28" # rocm6.4 "manylinux_2_28" # rocm6.5 + "manylinux_2_28" # rocm7.0 ) set(__AOTRITON_ROCM_LIST "rocm6.2" "rocm6.3" "rocm6.4" "rocm6.5" + "rocm7.0" ) set(__AOTRITON_CI_COMMIT "612896439fb4f78509b1a566b5ef0a333e9585bb") # source of rocm6.5 with gfx950 set(__AOTRITON_SHA256_LIST @@ -47,6 +50,7 @@ if(NOT __AOTRITON_INCLUDED) "9094d59717e7e6eace9126ca100dd0e86510f07fc6c3a349569fc4e2d9056604" # rocm6.3 "41190202c2736d5ff75b13a3abc0fb52ebfbb67226cf85dc3de7699c7000db44" # rocm6.4 "c85da64d21510190277794455ef8bd3f2d543a6f2462140d3da27e1df0ab8f82" # rocm6.5 with gfx950 + "9061bff8a1f7b857399467260b54714d659fd812a41eeee049f0a3e9c8b9aeeb" # rocm7.0 ) set(__AOTRITON_Z "gz") From f943edb60f4b28e3bb148f71295bbc35c4974a79 Mon Sep 17 00:00:00 2001 From: Xinya Zhang Date: Mon, 7 Jul 2025 20:50:42 +0000 Subject: [PATCH 07/27] Try to fix linking error --- test/cpp/c10d/CMakeLists.txt | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/test/cpp/c10d/CMakeLists.txt b/test/cpp/c10d/CMakeLists.txt index 5c8974836de56..024bf69113f3d 100644 --- a/test/cpp/c10d/CMakeLists.txt +++ b/test/cpp/c10d/CMakeLists.txt @@ -79,10 +79,7 @@ if(USE_MPI AND USE_C10D_MPI) # private headers of libtorch, which in turn include MPI. As a hacky # alternative to making MPI a public dependency of libtorch, we make it # a private dependency of the tests as well. 
- c10d_add_test(ProcessGroupMPITest.cpp torch_cpu MPI::MPI_CXX) - if(INSTALL_TEST) - install(TARGETS ProcessGroupMPITest DESTINATION bin) - endif() + c10d_add_test(ProcessGroupMPITest.cpp LINK_LIBRARIES torch_cpu MPI::MPI_CXX INSTALL_TEST ${INSTALL_TEST}) endif() if(LINUX AND USE_GLOO AND USE_C10D_GLOO) From e6b625a41d5d6bf75e4fd13c8bd1664e6da27cf3 Mon Sep 17 00:00:00 2001 From: Xinya Zhang Date: Mon, 7 Jul 2025 15:57:48 -0500 Subject: [PATCH 08/27] Add missing using aotriton::v2::flash::attn_fwd_compact_varlen --- aten/src/ATen/native/transformers/cuda/attention.cu | 1 + 1 file changed, 1 insertion(+) diff --git a/aten/src/ATen/native/transformers/cuda/attention.cu b/aten/src/ATen/native/transformers/cuda/attention.cu index d2e831c1a3f11..84b07e010cd67 100644 --- a/aten/src/ATen/native/transformers/cuda/attention.cu +++ b/aten/src/ATen/native/transformers/cuda/attention.cu @@ -1114,6 +1114,7 @@ std::tuple _efficient_ } using aotriton::v2::flash::attn_fwd; + using aotriton::v2::flash::attn_fwd_compact_varlen; using sdp::aotriton_adapter::mk_aotensor; using sdp::aotriton_adapter::mk_aoscalartensor; using sdp::aotriton_adapter::mk_philoxtensor; From 2c56ccd53a4dd8ac34339aa1f62c02848ed66c20 Mon Sep 17 00:00:00 2001 From: Xinya Zhang Date: Mon, 7 Jul 2025 15:58:08 -0500 Subject: [PATCH 09/27] Revert "Try to fix linking error" This reverts commit 2211aace36d46e98a0081e2ea91ef8c16818157c. --- test/cpp/c10d/CMakeLists.txt | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/test/cpp/c10d/CMakeLists.txt b/test/cpp/c10d/CMakeLists.txt index 024bf69113f3d..5c8974836de56 100644 --- a/test/cpp/c10d/CMakeLists.txt +++ b/test/cpp/c10d/CMakeLists.txt @@ -79,7 +79,10 @@ if(USE_MPI AND USE_C10D_MPI) # private headers of libtorch, which in turn include MPI. As a hacky # alternative to making MPI a public dependency of libtorch, we make it # a private dependency of the tests as well. - c10d_add_test(ProcessGroupMPITest.cpp LINK_LIBRARIES torch_cpu MPI::MPI_CXX INSTALL_TEST ${INSTALL_TEST}) + c10d_add_test(ProcessGroupMPITest.cpp torch_cpu MPI::MPI_CXX) + if(INSTALL_TEST) + install(TARGETS ProcessGroupMPITest DESTINATION bin) + endif() endif() if(LINUX AND USE_GLOO AND USE_C10D_GLOO) From 110fead28df5516dd81677e16fafe9c7ef0d5fee Mon Sep 17 00:00:00 2001 From: Mwiza Kunda Date: Fri, 25 Oct 2024 09:38:08 +0000 Subject: [PATCH 10/27] Set RUNPATH so installed tests can find the required shared libraries (#136627) This change fixes the RUNPATH of installed c++ tests so that the linker can find the shared libraries they depend on. 
For example, currently: ```bash venv/lib/python3.10/site-packages/torch $ ./bin/test_lazy ./bin/test_lazy: error while loading shared libraries: libtorch.so: cannot open shared object file: No such file or directory ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/136627 Approved by: https://github.com/malfet --- c10/benchmark/CMakeLists.txt | 1 + c10/test/CMakeLists.txt | 1 + caffe2/CMakeLists.txt | 6 ++++ test/cpp/api/CMakeLists.txt | 1 + test/cpp/c10d/CMakeLists.txt | 51 +++++++++++++-------------- test/cpp/dist_autograd/CMakeLists.txt | 1 + test/cpp/jit/CMakeLists.txt | 1 + test/cpp/lazy/CMakeLists.txt | 1 + test/cpp/rpc/CMakeLists.txt | 1 + test/cpp/tensorexpr/CMakeLists.txt | 2 ++ test/edge/CMakeLists.txt | 1 + 11 files changed, 40 insertions(+), 27 deletions(-) diff --git a/c10/benchmark/CMakeLists.txt b/c10/benchmark/CMakeLists.txt index 16b268e3800a0..8dee635d7e1d7 100644 --- a/c10/benchmark/CMakeLists.txt +++ b/c10/benchmark/CMakeLists.txt @@ -8,6 +8,7 @@ if(BUILD_TEST) add_executable(${bench_name} "${bench_src}") target_link_libraries(${bench_name} ${C10_LIB} benchmark) if(INSTALL_TEST) + set_target_properties(${bench_name} PROPERTIES INSTALL_RPATH "${CMAKE_INSTALL_RPATH}:${_rpath_portable_origin}/../lib") install(TARGETS ${bench_name} DESTINATION test) endif() endforeach() diff --git a/c10/test/CMakeLists.txt b/c10/test/CMakeLists.txt index 7f2a61246c6c6..83b5b17f9c8a6 100644 --- a/c10/test/CMakeLists.txt +++ b/c10/test/CMakeLists.txt @@ -12,6 +12,7 @@ if(BUILD_TEST) target_link_libraries(${test_name} ${C10_LIB} gmock gtest gtest_main) add_test(NAME ${test_name} COMMAND $) if(INSTALL_TEST) + set_target_properties(${test_name} PROPERTIES INSTALL_RPATH "${CMAKE_INSTALL_RPATH}:${_rpath_portable_origin}/../lib") install(TARGETS ${test_name} DESTINATION test) endif() endforeach() diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 3cb4b81f81504..615df5c04c82f 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -1775,6 +1775,7 @@ if(BUILD_TEST) target_include_directories(${test_name} PRIVATE ${Caffe2_CPU_INCLUDE}) add_test(NAME ${test_name} COMMAND $) if(INSTALL_TEST) + set_target_properties(${test_name} PROPERTIES INSTALL_RPATH "${CMAKE_INSTALL_RPATH}:${_rpath_portable_origin}/../lib") install(TARGETS ${test_name} DESTINATION test) # Install PDB files for MSVC builds if(MSVC AND BUILD_SHARED_LIBS) @@ -1795,6 +1796,7 @@ if(BUILD_TEST) target_include_directories(${test_name} PRIVATE ${Caffe2_CPU_INCLUDE}) add_test(NAME ${test_name} COMMAND $) if(INSTALL_TEST) + set_target_properties(${test_name} PROPERTIES INSTALL_RPATH "${CMAKE_INSTALL_RPATH}:${_rpath_portable_origin}/../lib") install(TARGETS ${test_name} DESTINATION test) # Install PDB files for MSVC builds if(MSVC AND BUILD_SHARED_LIBS) @@ -1816,6 +1818,7 @@ if(BUILD_TEST) target_include_directories(${test_name} PRIVATE ${Caffe2_CPU_INCLUDE}) add_test(NAME ${test_name} COMMAND $) if(INSTALL_TEST) + set_target_properties(${test_name} PROPERTIES INSTALL_RPATH "${CMAKE_INSTALL_RPATH}:${_rpath_portable_origin}/../lib") install(TARGETS ${test_name} DESTINATION test) # Install PDB files for MSVC builds if(MSVC AND BUILD_SHARED_LIBS) @@ -1837,6 +1840,7 @@ if(BUILD_TEST) target_include_directories(${test_name} PRIVATE ${Caffe2_CPU_INCLUDE}) add_test(NAME ${test_name} COMMAND $) if(INSTALL_TEST) + set_target_properties(${test_name} PROPERTIES INSTALL_RPATH "${CMAKE_INSTALL_RPATH}:${_rpath_portable_origin}/../lib") install(TARGETS ${test_name} DESTINATION test) endif() endforeach() @@ -1851,6 
+1855,7 @@ if(BUILD_TEST) target_include_directories(${test_name} PRIVATE ${Caffe2_CPU_INCLUDE}) add_test(NAME ${test_name} COMMAND $) if(INSTALL_TEST) + set_target_properties(${test_name} PROPERTIES INSTALL_RPATH "${CMAKE_INSTALL_RPATH}:${_rpath_portable_origin}/../lib") install(TARGETS ${test_name} DESTINATION test) # Install PDB files for MSVC builds if(MSVC AND BUILD_SHARED_LIBS) @@ -1870,6 +1875,7 @@ if(BUILD_TEST) target_compile_options(${test_name} PRIVATE ${HIP_CXX_FLAGS}) add_test(NAME ${test_name} COMMAND $) if(INSTALL_TEST) + set_target_properties(${test_name} PROPERTIES INSTALL_RPATH "${CMAKE_INSTALL_RPATH}:${_rpath_portable_origin}/../lib") install(TARGETS ${test_name} DESTINATION test) endif() endforeach() diff --git a/test/cpp/api/CMakeLists.txt b/test/cpp/api/CMakeLists.txt index ceeb607d52a7d..a62c3ecb53efb 100644 --- a/test/cpp/api/CMakeLists.txt +++ b/test/cpp/api/CMakeLists.txt @@ -68,6 +68,7 @@ if(NOT MSVC) endif() if(INSTALL_TEST) + set_target_properties(test_api PROPERTIES INSTALL_RPATH "${CMAKE_INSTALL_RPATH}:${_rpath_portable_origin}/../lib") install(TARGETS test_api DESTINATION bin) # Install PDB files for MSVC builds if(MSVC AND BUILD_SHARED_LIBS) diff --git a/test/cpp/c10d/CMakeLists.txt b/test/cpp/c10d/CMakeLists.txt index 5c8974836de56..17292790e0e63 100644 --- a/test/cpp/c10d/CMakeLists.txt +++ b/test/cpp/c10d/CMakeLists.txt @@ -6,36 +6,39 @@ if(USE_CUDA) endif() function(c10d_add_test test_src) + set(prefix ARG) + set(noValues) + set(singleValues INSTALL_TEST) + set(multiValues LINK_LIBRARIES) + + include(CMakeParseArguments) + cmake_parse_arguments(${prefix} "${noValues}" "${singleValues}" "${multiValues}" ${ARGN}) + get_filename_component(test_name ${test_src} NAME_WE) add_executable(${test_name} "${test_src}") target_include_directories(${test_name} PRIVATE $) - target_link_libraries(${test_name} ${ARGN}) + target_link_libraries(${test_name} ${ARG_LINK_LIBRARIES}) if(NOT WIN32) target_link_libraries(${test_name} pthread) endif() add_test(NAME ${test_name} COMMAND $) + + if(ARG_INSTALL_TEST) + set_target_properties(${test_name} PROPERTIES INSTALL_RPATH "${CMAKE_INSTALL_RPATH}:${_rpath_portable_origin}/../lib") + install(TARGETS ${test_name} DESTINATION bin) + endif() endfunction() -c10d_add_test(FileStoreTest.cpp torch_cpu gtest_main) -c10d_add_test(TCPStoreTest.cpp torch_cpu gtest_main) -if(INSTALL_TEST) - install(TARGETS FileStoreTest DESTINATION bin) - install(TARGETS TCPStoreTest DESTINATION bin) -endif() +c10d_add_test(FileStoreTest.cpp LINK_LIBRARIES torch_cpu gtest_main INSTALL_TEST ${INSTALL_TEST}) +c10d_add_test(TCPStoreTest.cpp LINK_LIBRARIES torch_cpu gtest_main INSTALL_TEST ${INSTALL_TEST}) if(NOT WIN32) - c10d_add_test(HashStoreTest.cpp torch_cpu gtest_main) - if(INSTALL_TEST) - install(TARGETS HashStoreTest DESTINATION bin) - endif() + c10d_add_test(HashStoreTest.cpp LINK_LIBRARIES torch_cpu gtest_main INSTALL_TEST ${INSTALL_TEST}) endif() if(USE_CUDA) if(USE_GLOO AND USE_C10D_GLOO) - c10d_add_test(ProcessGroupGlooTest.cpp torch_cpu c10d_cuda_test gtest_main) - if(INSTALL_TEST) - install(TARGETS ProcessGroupGlooTest DESTINATION bin) - endif() - c10d_add_test(ProcessGroupGlooAsyncTest.cpp torch_cpu c10d_cuda_test gtest_main) + c10d_add_test(ProcessGroupGlooTest.cpp LINK_LIBRARIES torch_cpu c10d_cuda_test gtest_main INSTALL_TEST ${INSTALL_TEST}) + c10d_add_test(ProcessGroupGlooAsyncTest.cpp LINK_LIBRARIES torch_cpu c10d_cuda_test gtest_main INSTALL_TEST ${INSTALL_TEST}) endif() if(USE_NCCL AND USE_C10D_NCCL) # NCCL is a private dependency of 
libtorch, but the tests include some @@ -44,13 +47,11 @@ if(USE_CUDA) # a private dependency of the tests as well. c10d_add_test( ProcessGroupNCCLTest.cpp - torch_cpu c10d_cuda_test gtest_main __caffe2_nccl) + LINK_LIBRARIES torch_cpu c10d_cuda_test gtest_main __caffe2_nccl INSTALL_TEST ${INSTALL_TEST}) c10d_add_test( ProcessGroupNCCLErrorsTest.cpp - torch_cpu c10d_cuda_test gtest_main __caffe2_nccl) + LINK_LIBRARIES torch_cpu c10d_cuda_test gtest_main __caffe2_nccl INSTALL_TEST ${INSTALL_TEST}) if(INSTALL_TEST) - install(TARGETS ProcessGroupNCCLTest DESTINATION bin) - install(TARGETS ProcessGroupNCCLErrorsTest DESTINATION bin) install(TARGETS c10d_cuda_test DESTINATION lib) endif() endif() @@ -61,15 +62,14 @@ if(USE_CUDA) # a private dependency of the tests as well. c10d_add_test( ProcessGroupUCCTest.cpp - torch_cpu c10d_cuda_test gtest_main __caffe2_ucc) + LINK_LIBRARIES torch_cpu c10d_cuda_test gtest_main __caffe2_ucc INSTALL_TEST ${INSTALL_TEST}) if(INSTALL_TEST) - install(TARGETS ProcessGroupUCCTest DESTINATION bin) install(TARGETS c10d_cuda_test DESTINATION lib) endif() endif() else() if(USE_GLOO AND USE_C10D_GLOO) - c10d_add_test(ProcessGroupGlooTest.cpp torch_cpu gtest_main) + c10d_add_test(ProcessGroupGlooTest.cpp LINK_LIBRARIES torch_cpu gtest_main INSTALL_TEST OFF) endif() endif() @@ -79,10 +79,7 @@ if(USE_MPI AND USE_C10D_MPI) # private headers of libtorch, which in turn include MPI. As a hacky # alternative to making MPI a public dependency of libtorch, we make it # a private dependency of the tests as well. - c10d_add_test(ProcessGroupMPITest.cpp torch_cpu MPI::MPI_CXX) - if(INSTALL_TEST) - install(TARGETS ProcessGroupMPITest DESTINATION bin) - endif() + c10d_add_test(ProcessGroupMPITest.cpp LINK_LIBRARIES torch_cpu MPI::MPI_CXX INSTALL_TEST ${INSTALL_TEST}) endif() if(LINUX AND USE_GLOO AND USE_C10D_GLOO) diff --git a/test/cpp/dist_autograd/CMakeLists.txt b/test/cpp/dist_autograd/CMakeLists.txt index 0ae6e3bef1410..6b5bba4b82086 100644 --- a/test/cpp/dist_autograd/CMakeLists.txt +++ b/test/cpp/dist_autograd/CMakeLists.txt @@ -14,6 +14,7 @@ if(USE_DISTRIBUTED AND NOT WIN32) endif() if(INSTALL_TEST) + set_target_properties(test_dist_autograd PROPERTIES INSTALL_RPATH "${CMAKE_INSTALL_RPATH}:${_rpath_portable_origin}/../lib") install(TARGETS test_dist_autograd DESTINATION bin) # Install PDB files for MSVC builds if(MSVC AND BUILD_SHARED_LIBS) diff --git a/test/cpp/jit/CMakeLists.txt b/test/cpp/jit/CMakeLists.txt index f0510d9c81f20..db6841d8d0547 100644 --- a/test/cpp/jit/CMakeLists.txt +++ b/test/cpp/jit/CMakeLists.txt @@ -151,6 +151,7 @@ elseif(USE_ROCM) endif() if(INSTALL_TEST) + set_target_properties(test_jit PROPERTIES INSTALL_RPATH "${CMAKE_INSTALL_RPATH}:${_rpath_portable_origin}/../lib") install(TARGETS test_jit DESTINATION bin) # Install PDB files for MSVC builds if(MSVC AND BUILD_SHARED_LIBS) diff --git a/test/cpp/lazy/CMakeLists.txt b/test/cpp/lazy/CMakeLists.txt index be37b47ac9b92..2fa4fabdf54dc 100644 --- a/test/cpp/lazy/CMakeLists.txt +++ b/test/cpp/lazy/CMakeLists.txt @@ -44,6 +44,7 @@ elseif(USE_ROCM) endif() if(INSTALL_TEST) + set_target_properties(test_lazy PROPERTIES INSTALL_RPATH "${CMAKE_INSTALL_RPATH}:${_rpath_portable_origin}/../lib") install(TARGETS test_lazy DESTINATION bin) # Install PDB files for MSVC builds if(MSVC AND BUILD_SHARED_LIBS) diff --git a/test/cpp/rpc/CMakeLists.txt b/test/cpp/rpc/CMakeLists.txt index 6834b428ff937..5c3a0dc020de9 100644 --- a/test/cpp/rpc/CMakeLists.txt +++ b/test/cpp/rpc/CMakeLists.txt @@ -37,6 +37,7 @@ if(USE_CUDA) 
endif() if(INSTALL_TEST) + set_target_properties(test_cpp_rpc PROPERTIES INSTALL_RPATH "${CMAKE_INSTALL_RPATH}:${_rpath_portable_origin}/../lib") install(TARGETS test_cpp_rpc DESTINATION bin) # Install PDB files for MSVC builds if(MSVC AND BUILD_SHARED_LIBS) diff --git a/test/cpp/tensorexpr/CMakeLists.txt b/test/cpp/tensorexpr/CMakeLists.txt index 179270c4a4a15..8cf4803935106 100644 --- a/test/cpp/tensorexpr/CMakeLists.txt +++ b/test/cpp/tensorexpr/CMakeLists.txt @@ -71,7 +71,9 @@ elseif(USE_ROCM) endif() if(INSTALL_TEST) + set_target_properties(test_tensorexpr PROPERTIES INSTALL_RPATH "${CMAKE_INSTALL_RPATH}:${_rpath_portable_origin}/../lib") install(TARGETS test_tensorexpr DESTINATION bin) + set_target_properties(tutorial_tensorexpr PROPERTIES INSTALL_RPATH "${CMAKE_INSTALL_RPATH}:${_rpath_portable_origin}/../lib") install(TARGETS tutorial_tensorexpr DESTINATION bin) # Install PDB files for MSVC builds if(MSVC AND BUILD_SHARED_LIBS) diff --git a/test/edge/CMakeLists.txt b/test/edge/CMakeLists.txt index 50579c9109dc8..72c01a2d36492 100644 --- a/test/edge/CMakeLists.txt +++ b/test/edge/CMakeLists.txt @@ -73,5 +73,6 @@ elseif(CMAKE_CXX_COMPILER_ID MATCHES "Clang|GNU") ) endif() if(INSTALL_TEST) + set_target_properties(test_edge_op_registration PROPERTIES INSTALL_RPATH "${CMAKE_INSTALL_RPATH}:${_rpath_portable_origin}/../lib") install(TARGETS test_edge_op_registration DESTINATION bin) endif() From 34adeb7cb39226cf4c31d50f6c90bc848483c0a3 Mon Sep 17 00:00:00 2001 From: cyyever Date: Tue, 29 Oct 2024 23:14:40 +0000 Subject: [PATCH 11/27] [CMake] Remove pthread linking (#134436) Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/134436 Approved by: https://github.com/r-barnes --- test/cpp/c10d/CMakeLists.txt | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/test/cpp/c10d/CMakeLists.txt b/test/cpp/c10d/CMakeLists.txt index 17292790e0e63..5eabec63aaed4 100644 --- a/test/cpp/c10d/CMakeLists.txt +++ b/test/cpp/c10d/CMakeLists.txt @@ -18,9 +18,6 @@ function(c10d_add_test test_src) add_executable(${test_name} "${test_src}") target_include_directories(${test_name} PRIVATE $) target_link_libraries(${test_name} ${ARG_LINK_LIBRARIES}) - if(NOT WIN32) - target_link_libraries(${test_name} pthread) - endif() add_test(NAME ${test_name} COMMAND $) if(ARG_INSTALL_TEST) @@ -85,7 +82,7 @@ endif() if(LINUX AND USE_GLOO AND USE_C10D_GLOO) add_executable(example_allreduce example/allreduce.cpp) target_include_directories(example_allreduce PRIVATE $) - target_link_libraries(example_allreduce pthread torch_cpu) + target_link_libraries(example_allreduce torch_cpu) if(USE_CUDA) target_link_libraries(example_allreduce torch_cuda) endif() From b47f007afed75ef58172efdb9d4af8023e314df7 Mon Sep 17 00:00:00 2001 From: Xinya Zhang Date: Mon, 7 Jul 2025 16:20:24 -0500 Subject: [PATCH 12/27] build: add missing file --- aten/src/ATen/CMakeLists.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index 0087dd95d96ee..8c8dc6662e00e 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -174,6 +174,7 @@ file(GLOB flash_attention_cuda_cpp "native/transformers/cuda/flash_attn/*.cpp") # flash_attention sources file(GLOB flash_attention_hip_hip "native/transformers/hip/flash_attn/*.hip") file(GLOB flash_attention_src_hip_hip "native/transformers/hip/flash_attn/src/*.hip") +file(GLOB flash_attention_hip_aot_hip "native/transformers/hip/flash_attn/aot/*.hip") #Mem_eff attention sources file(GLOB 
mem_eff_attention_cuda_cu "native/transformers/cuda/mem_eff_attention/*.cu") @@ -188,6 +189,7 @@ if(USE_FLASH_ATTENTION) list(APPEND ATen_ATTENTION_KERNEL_SRCS ${flash_attention_cuda_kernels_cu}) list(APPEND native_transformers_hip_hip ${flash_attention_hip_hip}) + list(APPEND native_transformers_hip_hip ${flash_attention_hip_aot_hip}) list(APPEND native_transformers_src_hip_hip ${flash_attention_src_hip_hip}) endif() From 8e80a7b953af179fbe70214e426e226ad768bf69 Mon Sep 17 00:00:00 2001 From: Xinya Zhang Date: Mon, 7 Jul 2025 16:22:15 -0500 Subject: [PATCH 13/27] do not hipify tools/amd_build/build_amd.py --- tools/amd_build/build_amd.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tools/amd_build/build_amd.py b/tools/amd_build/build_amd.py index 96047e61f0304..3b035df1d29e8 100755 --- a/tools/amd_build/build_amd.py +++ b/tools/amd_build/build_amd.py @@ -97,7 +97,6 @@ "aten/src/ATen/native/transformers/cuda/mem_eff_attention/debug_utils.h", "aten/src/ATen/native/transformers/cuda/mem_eff_attention/gemm_kernel_utils.h", "aten/src/ATen/native/transformers/cuda/mem_eff_attention/pytorch_utils.h", - "aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.h", "aten/src/THC/*", "aten/src/ATen/test/*", # CMakeLists.txt isn't processed by default, but there are a few From 102b3f3e7912aa04490455e490782a75b0b8efb6 Mon Sep 17 00:00:00 2001 From: Xinya Zhang Date: Mon, 7 Jul 2025 21:22:31 +0000 Subject: [PATCH 14/27] Revert "[CMake] Remove pthread linking (#134436)" This reverts commit b7eaa03dc469a73e4fe10f93fa779180c96c763e. --- test/cpp/c10d/CMakeLists.txt | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/test/cpp/c10d/CMakeLists.txt b/test/cpp/c10d/CMakeLists.txt index 5eabec63aaed4..17292790e0e63 100644 --- a/test/cpp/c10d/CMakeLists.txt +++ b/test/cpp/c10d/CMakeLists.txt @@ -18,6 +18,9 @@ function(c10d_add_test test_src) add_executable(${test_name} "${test_src}") target_include_directories(${test_name} PRIVATE $) target_link_libraries(${test_name} ${ARG_LINK_LIBRARIES}) + if(NOT WIN32) + target_link_libraries(${test_name} pthread) + endif() add_test(NAME ${test_name} COMMAND $) if(ARG_INSTALL_TEST) @@ -82,7 +85,7 @@ endif() if(LINUX AND USE_GLOO AND USE_C10D_GLOO) add_executable(example_allreduce example/allreduce.cpp) target_include_directories(example_allreduce PRIVATE $) - target_link_libraries(example_allreduce torch_cpu) + target_link_libraries(example_allreduce pthread torch_cpu) if(USE_CUDA) target_link_libraries(example_allreduce torch_cuda) endif() From a143ee557e4cb42fe1c936d017f77e0040be8a3e Mon Sep 17 00:00:00 2001 From: Xinya Zhang Date: Mon, 7 Jul 2025 21:22:33 +0000 Subject: [PATCH 15/27] Revert "Set RUNPATH so installed tests can find the required shared libraries (#136627)" This reverts commit 84040be83ee4f8850e2384065415c1f8c8e997a5. 
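For reference, the RUNPATH mechanism being reverted here works by giving each installed test binary an $ORIGIN-relative RPATH so the dynamic linker can locate libtorch.so inside the installed tree. A minimal CMake sketch of that pattern follows; the target and source names are hypothetical, and it assumes `_rpath_portable_origin` is defined elsewhere in the build (in PyTorch's CMake it typically expands to `$ORIGIN` on Linux and `@loader_path` on macOS):

```cmake
# Editorial sketch, not part of this patch: install a test under <prefix>/bin
# and let it find shared libraries under <prefix>/lib via an $ORIGIN-relative RPATH.
add_executable(example_test example_test.cpp)   # hypothetical target and source
target_link_libraries(example_test torch_cpu)

if(INSTALL_TEST)
  # ${_rpath_portable_origin} is assumed to expand to $ORIGIN on Linux,
  # so the installed binary looks in ../lib relative to its own location.
  set_target_properties(example_test PROPERTIES
    INSTALL_RPATH "${CMAKE_INSTALL_RPATH}:${_rpath_portable_origin}/../lib")
  install(TARGETS example_test DESTINATION bin)
endif()
```

With this property set, the installed binary should resolve libtorch.so via `$ORIGIN/../lib` without LD_LIBRARY_PATH, which is the failure mode shown in the earlier `./bin/test_lazy` example; reverting the change restores the previous behavior where installed tests rely on the environment to find their libraries.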
--- c10/benchmark/CMakeLists.txt | 1 - c10/test/CMakeLists.txt | 1 - caffe2/CMakeLists.txt | 6 ---- test/cpp/api/CMakeLists.txt | 1 - test/cpp/c10d/CMakeLists.txt | 51 ++++++++++++++------------- test/cpp/dist_autograd/CMakeLists.txt | 1 - test/cpp/jit/CMakeLists.txt | 1 - test/cpp/lazy/CMakeLists.txt | 1 - test/cpp/rpc/CMakeLists.txt | 1 - test/cpp/tensorexpr/CMakeLists.txt | 2 -- test/edge/CMakeLists.txt | 1 - 11 files changed, 27 insertions(+), 40 deletions(-) diff --git a/c10/benchmark/CMakeLists.txt b/c10/benchmark/CMakeLists.txt index 8dee635d7e1d7..16b268e3800a0 100644 --- a/c10/benchmark/CMakeLists.txt +++ b/c10/benchmark/CMakeLists.txt @@ -8,7 +8,6 @@ if(BUILD_TEST) add_executable(${bench_name} "${bench_src}") target_link_libraries(${bench_name} ${C10_LIB} benchmark) if(INSTALL_TEST) - set_target_properties(${bench_name} PROPERTIES INSTALL_RPATH "${CMAKE_INSTALL_RPATH}:${_rpath_portable_origin}/../lib") install(TARGETS ${bench_name} DESTINATION test) endif() endforeach() diff --git a/c10/test/CMakeLists.txt b/c10/test/CMakeLists.txt index 83b5b17f9c8a6..7f2a61246c6c6 100644 --- a/c10/test/CMakeLists.txt +++ b/c10/test/CMakeLists.txt @@ -12,7 +12,6 @@ if(BUILD_TEST) target_link_libraries(${test_name} ${C10_LIB} gmock gtest gtest_main) add_test(NAME ${test_name} COMMAND $) if(INSTALL_TEST) - set_target_properties(${test_name} PROPERTIES INSTALL_RPATH "${CMAKE_INSTALL_RPATH}:${_rpath_portable_origin}/../lib") install(TARGETS ${test_name} DESTINATION test) endif() endforeach() diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 615df5c04c82f..3cb4b81f81504 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -1775,7 +1775,6 @@ if(BUILD_TEST) target_include_directories(${test_name} PRIVATE ${Caffe2_CPU_INCLUDE}) add_test(NAME ${test_name} COMMAND $) if(INSTALL_TEST) - set_target_properties(${test_name} PROPERTIES INSTALL_RPATH "${CMAKE_INSTALL_RPATH}:${_rpath_portable_origin}/../lib") install(TARGETS ${test_name} DESTINATION test) # Install PDB files for MSVC builds if(MSVC AND BUILD_SHARED_LIBS) @@ -1796,7 +1795,6 @@ if(BUILD_TEST) target_include_directories(${test_name} PRIVATE ${Caffe2_CPU_INCLUDE}) add_test(NAME ${test_name} COMMAND $) if(INSTALL_TEST) - set_target_properties(${test_name} PROPERTIES INSTALL_RPATH "${CMAKE_INSTALL_RPATH}:${_rpath_portable_origin}/../lib") install(TARGETS ${test_name} DESTINATION test) # Install PDB files for MSVC builds if(MSVC AND BUILD_SHARED_LIBS) @@ -1818,7 +1816,6 @@ if(BUILD_TEST) target_include_directories(${test_name} PRIVATE ${Caffe2_CPU_INCLUDE}) add_test(NAME ${test_name} COMMAND $) if(INSTALL_TEST) - set_target_properties(${test_name} PROPERTIES INSTALL_RPATH "${CMAKE_INSTALL_RPATH}:${_rpath_portable_origin}/../lib") install(TARGETS ${test_name} DESTINATION test) # Install PDB files for MSVC builds if(MSVC AND BUILD_SHARED_LIBS) @@ -1840,7 +1837,6 @@ if(BUILD_TEST) target_include_directories(${test_name} PRIVATE ${Caffe2_CPU_INCLUDE}) add_test(NAME ${test_name} COMMAND $) if(INSTALL_TEST) - set_target_properties(${test_name} PROPERTIES INSTALL_RPATH "${CMAKE_INSTALL_RPATH}:${_rpath_portable_origin}/../lib") install(TARGETS ${test_name} DESTINATION test) endif() endforeach() @@ -1855,7 +1851,6 @@ if(BUILD_TEST) target_include_directories(${test_name} PRIVATE ${Caffe2_CPU_INCLUDE}) add_test(NAME ${test_name} COMMAND $) if(INSTALL_TEST) - set_target_properties(${test_name} PROPERTIES INSTALL_RPATH "${CMAKE_INSTALL_RPATH}:${_rpath_portable_origin}/../lib") install(TARGETS ${test_name} DESTINATION test) # Install PDB 
files for MSVC builds if(MSVC AND BUILD_SHARED_LIBS) @@ -1875,7 +1870,6 @@ if(BUILD_TEST) target_compile_options(${test_name} PRIVATE ${HIP_CXX_FLAGS}) add_test(NAME ${test_name} COMMAND $) if(INSTALL_TEST) - set_target_properties(${test_name} PROPERTIES INSTALL_RPATH "${CMAKE_INSTALL_RPATH}:${_rpath_portable_origin}/../lib") install(TARGETS ${test_name} DESTINATION test) endif() endforeach() diff --git a/test/cpp/api/CMakeLists.txt b/test/cpp/api/CMakeLists.txt index a62c3ecb53efb..ceeb607d52a7d 100644 --- a/test/cpp/api/CMakeLists.txt +++ b/test/cpp/api/CMakeLists.txt @@ -68,7 +68,6 @@ if(NOT MSVC) endif() if(INSTALL_TEST) - set_target_properties(test_api PROPERTIES INSTALL_RPATH "${CMAKE_INSTALL_RPATH}:${_rpath_portable_origin}/../lib") install(TARGETS test_api DESTINATION bin) # Install PDB files for MSVC builds if(MSVC AND BUILD_SHARED_LIBS) diff --git a/test/cpp/c10d/CMakeLists.txt b/test/cpp/c10d/CMakeLists.txt index 17292790e0e63..5c8974836de56 100644 --- a/test/cpp/c10d/CMakeLists.txt +++ b/test/cpp/c10d/CMakeLists.txt @@ -6,39 +6,36 @@ if(USE_CUDA) endif() function(c10d_add_test test_src) - set(prefix ARG) - set(noValues) - set(singleValues INSTALL_TEST) - set(multiValues LINK_LIBRARIES) - - include(CMakeParseArguments) - cmake_parse_arguments(${prefix} "${noValues}" "${singleValues}" "${multiValues}" ${ARGN}) - get_filename_component(test_name ${test_src} NAME_WE) add_executable(${test_name} "${test_src}") target_include_directories(${test_name} PRIVATE $) - target_link_libraries(${test_name} ${ARG_LINK_LIBRARIES}) + target_link_libraries(${test_name} ${ARGN}) if(NOT WIN32) target_link_libraries(${test_name} pthread) endif() add_test(NAME ${test_name} COMMAND $) - - if(ARG_INSTALL_TEST) - set_target_properties(${test_name} PROPERTIES INSTALL_RPATH "${CMAKE_INSTALL_RPATH}:${_rpath_portable_origin}/../lib") - install(TARGETS ${test_name} DESTINATION bin) - endif() endfunction() -c10d_add_test(FileStoreTest.cpp LINK_LIBRARIES torch_cpu gtest_main INSTALL_TEST ${INSTALL_TEST}) -c10d_add_test(TCPStoreTest.cpp LINK_LIBRARIES torch_cpu gtest_main INSTALL_TEST ${INSTALL_TEST}) +c10d_add_test(FileStoreTest.cpp torch_cpu gtest_main) +c10d_add_test(TCPStoreTest.cpp torch_cpu gtest_main) +if(INSTALL_TEST) + install(TARGETS FileStoreTest DESTINATION bin) + install(TARGETS TCPStoreTest DESTINATION bin) +endif() if(NOT WIN32) - c10d_add_test(HashStoreTest.cpp LINK_LIBRARIES torch_cpu gtest_main INSTALL_TEST ${INSTALL_TEST}) + c10d_add_test(HashStoreTest.cpp torch_cpu gtest_main) + if(INSTALL_TEST) + install(TARGETS HashStoreTest DESTINATION bin) + endif() endif() if(USE_CUDA) if(USE_GLOO AND USE_C10D_GLOO) - c10d_add_test(ProcessGroupGlooTest.cpp LINK_LIBRARIES torch_cpu c10d_cuda_test gtest_main INSTALL_TEST ${INSTALL_TEST}) - c10d_add_test(ProcessGroupGlooAsyncTest.cpp LINK_LIBRARIES torch_cpu c10d_cuda_test gtest_main INSTALL_TEST ${INSTALL_TEST}) + c10d_add_test(ProcessGroupGlooTest.cpp torch_cpu c10d_cuda_test gtest_main) + if(INSTALL_TEST) + install(TARGETS ProcessGroupGlooTest DESTINATION bin) + endif() + c10d_add_test(ProcessGroupGlooAsyncTest.cpp torch_cpu c10d_cuda_test gtest_main) endif() if(USE_NCCL AND USE_C10D_NCCL) # NCCL is a private dependency of libtorch, but the tests include some @@ -47,11 +44,13 @@ if(USE_CUDA) # a private dependency of the tests as well. 
c10d_add_test( ProcessGroupNCCLTest.cpp - LINK_LIBRARIES torch_cpu c10d_cuda_test gtest_main __caffe2_nccl INSTALL_TEST ${INSTALL_TEST}) + torch_cpu c10d_cuda_test gtest_main __caffe2_nccl) c10d_add_test( ProcessGroupNCCLErrorsTest.cpp - LINK_LIBRARIES torch_cpu c10d_cuda_test gtest_main __caffe2_nccl INSTALL_TEST ${INSTALL_TEST}) + torch_cpu c10d_cuda_test gtest_main __caffe2_nccl) if(INSTALL_TEST) + install(TARGETS ProcessGroupNCCLTest DESTINATION bin) + install(TARGETS ProcessGroupNCCLErrorsTest DESTINATION bin) install(TARGETS c10d_cuda_test DESTINATION lib) endif() endif() @@ -62,14 +61,15 @@ if(USE_CUDA) # a private dependency of the tests as well. c10d_add_test( ProcessGroupUCCTest.cpp - LINK_LIBRARIES torch_cpu c10d_cuda_test gtest_main __caffe2_ucc INSTALL_TEST ${INSTALL_TEST}) + torch_cpu c10d_cuda_test gtest_main __caffe2_ucc) if(INSTALL_TEST) + install(TARGETS ProcessGroupUCCTest DESTINATION bin) install(TARGETS c10d_cuda_test DESTINATION lib) endif() endif() else() if(USE_GLOO AND USE_C10D_GLOO) - c10d_add_test(ProcessGroupGlooTest.cpp LINK_LIBRARIES torch_cpu gtest_main INSTALL_TEST OFF) + c10d_add_test(ProcessGroupGlooTest.cpp torch_cpu gtest_main) endif() endif() @@ -79,7 +79,10 @@ if(USE_MPI AND USE_C10D_MPI) # private headers of libtorch, which in turn include MPI. As a hacky # alternative to making MPI a public dependency of libtorch, we make it # a private dependency of the tests as well. - c10d_add_test(ProcessGroupMPITest.cpp LINK_LIBRARIES torch_cpu MPI::MPI_CXX INSTALL_TEST ${INSTALL_TEST}) + c10d_add_test(ProcessGroupMPITest.cpp torch_cpu MPI::MPI_CXX) + if(INSTALL_TEST) + install(TARGETS ProcessGroupMPITest DESTINATION bin) + endif() endif() if(LINUX AND USE_GLOO AND USE_C10D_GLOO) diff --git a/test/cpp/dist_autograd/CMakeLists.txt b/test/cpp/dist_autograd/CMakeLists.txt index 6b5bba4b82086..0ae6e3bef1410 100644 --- a/test/cpp/dist_autograd/CMakeLists.txt +++ b/test/cpp/dist_autograd/CMakeLists.txt @@ -14,7 +14,6 @@ if(USE_DISTRIBUTED AND NOT WIN32) endif() if(INSTALL_TEST) - set_target_properties(test_dist_autograd PROPERTIES INSTALL_RPATH "${CMAKE_INSTALL_RPATH}:${_rpath_portable_origin}/../lib") install(TARGETS test_dist_autograd DESTINATION bin) # Install PDB files for MSVC builds if(MSVC AND BUILD_SHARED_LIBS) diff --git a/test/cpp/jit/CMakeLists.txt b/test/cpp/jit/CMakeLists.txt index db6841d8d0547..f0510d9c81f20 100644 --- a/test/cpp/jit/CMakeLists.txt +++ b/test/cpp/jit/CMakeLists.txt @@ -151,7 +151,6 @@ elseif(USE_ROCM) endif() if(INSTALL_TEST) - set_target_properties(test_jit PROPERTIES INSTALL_RPATH "${CMAKE_INSTALL_RPATH}:${_rpath_portable_origin}/../lib") install(TARGETS test_jit DESTINATION bin) # Install PDB files for MSVC builds if(MSVC AND BUILD_SHARED_LIBS) diff --git a/test/cpp/lazy/CMakeLists.txt b/test/cpp/lazy/CMakeLists.txt index 2fa4fabdf54dc..be37b47ac9b92 100644 --- a/test/cpp/lazy/CMakeLists.txt +++ b/test/cpp/lazy/CMakeLists.txt @@ -44,7 +44,6 @@ elseif(USE_ROCM) endif() if(INSTALL_TEST) - set_target_properties(test_lazy PROPERTIES INSTALL_RPATH "${CMAKE_INSTALL_RPATH}:${_rpath_portable_origin}/../lib") install(TARGETS test_lazy DESTINATION bin) # Install PDB files for MSVC builds if(MSVC AND BUILD_SHARED_LIBS) diff --git a/test/cpp/rpc/CMakeLists.txt b/test/cpp/rpc/CMakeLists.txt index 5c3a0dc020de9..6834b428ff937 100644 --- a/test/cpp/rpc/CMakeLists.txt +++ b/test/cpp/rpc/CMakeLists.txt @@ -37,7 +37,6 @@ if(USE_CUDA) endif() if(INSTALL_TEST) - set_target_properties(test_cpp_rpc PROPERTIES INSTALL_RPATH 
"${CMAKE_INSTALL_RPATH}:${_rpath_portable_origin}/../lib") install(TARGETS test_cpp_rpc DESTINATION bin) # Install PDB files for MSVC builds if(MSVC AND BUILD_SHARED_LIBS) diff --git a/test/cpp/tensorexpr/CMakeLists.txt b/test/cpp/tensorexpr/CMakeLists.txt index 8cf4803935106..179270c4a4a15 100644 --- a/test/cpp/tensorexpr/CMakeLists.txt +++ b/test/cpp/tensorexpr/CMakeLists.txt @@ -71,9 +71,7 @@ elseif(USE_ROCM) endif() if(INSTALL_TEST) - set_target_properties(test_tensorexpr PROPERTIES INSTALL_RPATH "${CMAKE_INSTALL_RPATH}:${_rpath_portable_origin}/../lib") install(TARGETS test_tensorexpr DESTINATION bin) - set_target_properties(tutorial_tensorexpr PROPERTIES INSTALL_RPATH "${CMAKE_INSTALL_RPATH}:${_rpath_portable_origin}/../lib") install(TARGETS tutorial_tensorexpr DESTINATION bin) # Install PDB files for MSVC builds if(MSVC AND BUILD_SHARED_LIBS) diff --git a/test/edge/CMakeLists.txt b/test/edge/CMakeLists.txt index 72c01a2d36492..50579c9109dc8 100644 --- a/test/edge/CMakeLists.txt +++ b/test/edge/CMakeLists.txt @@ -73,6 +73,5 @@ elseif(CMAKE_CXX_COMPILER_ID MATCHES "Clang|GNU") ) endif() if(INSTALL_TEST) - set_target_properties(test_edge_op_registration PROPERTIES INSTALL_RPATH "${CMAKE_INSTALL_RPATH}:${_rpath_portable_origin}/../lib") install(TARGETS test_edge_op_registration DESTINATION bin) endif() From e75771d2f6fed7baee6f50cd90a5b7cb78d61aa0 Mon Sep 17 00:00:00 2001 From: Xinya Zhang Date: Mon, 7 Jul 2025 21:42:18 +0000 Subject: [PATCH 16/27] fix build error --- .../transformers/hip/flash_attn/flash_api.h | 203 ------------------ 1 file changed, 203 deletions(-) diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h b/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h index 4daaa66e8a1a2..d503d6bb2ee93 100644 --- a/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h +++ b/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h @@ -273,39 +273,6 @@ mha_fwd( const float softcap, const bool return_softmax, std::optional gen_) { -#if defined(USE_CK_FLASH_ATTENTION) - if (at::globalContext().getROCmFAPreferredBackend() == - at::ROCmFABackend::Ck) { - std::optional dummy_attn_bias = std::nullopt; - return mha_fwd_ck( - q, - k, - v, - out_, - p_dropout, - softmax_scale, - is_causal, - window_size_left, - window_size_right, - return_softmax, - gen_, - dummy_attn_bias); // Not used in flash attention - } else { - return mha_fwd_aot( - q, - k, - v, - out_, - alibi_slopes_, - p_dropout, - softmax_scale, - is_causal, - window_size_left, - window_size_right, - return_softmax, - gen_); - } -#else return mha_fwd_aot( q, k, @@ -319,7 +286,6 @@ mha_fwd( window_size_right, return_softmax, gen_); -#endif } inline std::tuple< @@ -359,52 +325,6 @@ mha_varlen_fwd( const float softcap, const bool return_softmax, std::optional gen_) { -#if defined(USE_CK_FLASH_ATTENTION) - if (at::globalContext().getROCmFAPreferredBackend() == - at::ROCmFABackend::Ck) { - std::optional dummy_attn_bias = std::nullopt; - return mha_varlen_fwd_ck( - q, - k, - v, - out_, - cu_seqlens_q, - cu_seqlens_k, - seqused_k, - max_seqlen_q, - max_seqlen_k, - p_dropout, - softmax_scale, - zero_tensors, - is_causal, - window_size_left, - window_size_right, - return_softmax, - gen_, - dummy_attn_bias); // Not used in flash attention - } else { - return mha_varlen_fwd_aot( - q, - k, - v, - out_, - cu_seqlens_q, - cu_seqlens_k, - seqused_k, - block_table_, - alibi_slopes_, - max_seqlen_q, - max_seqlen_k, - p_dropout, - softmax_scale, - zero_tensors, - is_causal, - window_size_left, - 
window_size_right, - return_softmax, - gen_); - } -#else return mha_varlen_fwd_aot( q, k, @@ -425,7 +345,6 @@ mha_varlen_fwd( window_size_right, return_softmax, gen_); -#endif } inline std::tuple mha_bwd( @@ -452,63 +371,6 @@ inline std::tuple mha_bwd( const bool deterministic, const at::Tensor philox_seed, const at::Tensor philox_offset) { -#if defined(USE_CK_FLASH_ATTENTION) - if (at::globalContext().getROCmFAPreferredBackend() == - at::ROCmFABackend::Ck) { - std::optional non_null_dbias = std::nullopt; - auto[dQuery, - dKey, - dValue, - dSoftmax, - dBias] = mha_bwd_ck( - dout, - q, - k, - v, - out, - softmax_lse, - dq_, - dk_, - dv_, - alibi_slopes_, - false, // bias_requires_grad - non_null_dbias, - p_dropout, - softmax_scale, - is_causal, - window_size_left, - window_size_right, - deterministic, - philox_seed, - philox_offset); - // for FA return [dQ, dV, dK, dSoftmax] - return std::make_tuple(std::move(dQuery), std::move(dKey), std::move(dValue), std::move(dSoftmax)); - } else { - return mha_bwd_aot( - dout, - q, - k, - v, - out, - softmax_lse, - dq_, - dk_, - dv_, - alibi_slopes_, - p_dropout, - softmax_scale, - is_causal, - window_size_left, - window_size_right, - deterministic, - philox_seed, - philox_offset); - } -#else - if(at::globalContext().getROCmFAPreferredBackend() == - at::ROCmFABackend::Ck) { - TORCH_WARN_ONCE("Warning! You have opted to use CK flash attention backend in a build that was not compiled using USE_CK_FLASH_ATTENTION=1. Please set this variable and try again. Defaulting to use aotriton backend..."); - } return mha_bwd_aot( dout, q, @@ -528,7 +390,6 @@ inline std::tuple mha_bwd( deterministic, philox_seed, philox_offset); -#endif } inline std::tuple mha_varlen_bwd( @@ -562,69 +423,6 @@ inline std::tuple mha_varlen_bwd const bool deterministic, const at::Tensor philox_seed, const at::Tensor philox_offset) { -#if defined(USE_CK_FLASH_ATTENTION) - if (at::globalContext().getROCmFAPreferredBackend() == - at::ROCmFABackend::Ck) { - std::optional non_null_dbias = std::nullopt; - auto[dQuery, - dKey, - dValue, - dSoftmax, - dBias] = mha_varlen_bwd_ck( - dout, - q, - k, - v, - out, - softmax_lse, - dq_, - dk_, - dv_, - cu_seqlens_q, - cu_seqlens_k, - alibi_slopes_, - false, // bias_requires_grad - non_null_dbias, - max_seqlen_q, - max_seqlen_k, - p_dropout, - softmax_scale, - zero_tensors, - is_causal, - window_size_left, - window_size_right, - deterministic, - philox_seed, - philox_offset); - // for FA return [dQ, dV, dK, dSoftmax] - return std::make_tuple(std::move(dQuery), std::move(dKey), std::move(dValue), std::move(dSoftmax)); - } else { - return mha_varlen_bwd_aot( - dout, - q, - k, - v, - out, - softmax_lse, - dq_, - dk_, - dv_, - cu_seqlens_q, - cu_seqlens_k, - alibi_slopes_, - max_seqlen_q, - max_seqlen_k, - p_dropout, - softmax_scale, - zero_tensors, - is_causal, - window_size_left, - window_size_right, - deterministic, - philox_seed, - philox_offset); - } -#else return mha_varlen_bwd_aot( dout, q, @@ -649,7 +447,6 @@ inline std::tuple mha_varlen_bwd deterministic, philox_seed, philox_offset); -#endif } } // namespace pytorch_flash From d62e294bc81cd41fc1a823b16a8573a961e06703 Mon Sep 17 00:00:00 2001 From: Xinya Zhang Date: Mon, 7 Jul 2025 21:44:41 +0000 Subject: [PATCH 17/27] fix build error --- .../transformers/hip/flash_attn/aot/mha_all_aot.hip | 5 +---- .../ATen/native/transformers/hip/flash_attn/flash_api.h | 8 -------- 2 files changed, 1 insertion(+), 12 deletions(-) diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip 
b/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip index adaa837f755cb..3dbb7448896d5 100644 --- a/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip +++ b/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip @@ -255,7 +255,6 @@ mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, tot const at::Tensor &cu_seqlens_q, // b+1 const at::Tensor &cu_seqlens_k, // b+1 std::optional &seqused_k, // b. If given, only this many elements of each batch element's keys are used. - std::optional &block_table_, // batch_size x max_num_blocks_per_seq std::optional &alibi_slopes_, // num_heads or b x num_heads int max_seqlen_q, const int max_seqlen_k, @@ -268,8 +267,6 @@ mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, tot const bool return_softmax, const std::optional& gen_) { TORCH_CHECK(!seqused_k.has_value(), "[ROCm] mha_varlen_fwd: seqused_k must be nullopt"); - const bool paged_KV = block_table_.has_value(); - TORCH_CHECK(!paged_KV, "[ROCm] mha_varlen_fwd: block_table_ must be nullopt"); TORCH_CHECK(!alibi_slopes_.has_value(), "[ROCm] mha_varlen_fwd: alibi_slopes_ must be nullopt"); at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()}; @@ -299,7 +296,7 @@ mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, tot const int batch_size = cu_seqlens_q.numel() - 1; int num_heads = sizes[1]; const int head_size_og = sizes[2]; - const int num_heads_k = paged_KV ? k.size(2) : k.size(1); + const int num_heads_k = k.size(1); if (max_seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h b/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h index d503d6bb2ee93..b0dc8d2576a48 100644 --- a/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h +++ b/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h @@ -57,7 +57,6 @@ mha_varlen_fwd_aot( std::optional& seqused_k, // b. If given, only this many elements of each batch // element's keys are used. - std::optional& block_table_, std::optional& alibi_slopes_, // num_heads or b x num_heads int max_seqlen_q, const int max_seqlen_k, @@ -270,7 +269,6 @@ mha_fwd( bool is_causal, int window_size_left, int window_size_right, - const float softcap, const bool return_softmax, std::optional gen_) { return mha_fwd_aot( @@ -311,8 +309,6 @@ mha_varlen_fwd( std::optional& seqused_k, // b. If given, only this many elements of each batch // element's keys are used. - std::optional& - block_table_, // Not used on ROCm. 
Keeping for parity with CUDA std::optional& alibi_slopes_, // num_heads or b x num_heads int max_seqlen_q, const int max_seqlen_k, @@ -322,7 +318,6 @@ mha_varlen_fwd( bool is_causal, int window_size_left, int window_size_right, - const float softcap, const bool return_softmax, std::optional gen_) { return mha_varlen_fwd_aot( @@ -333,7 +328,6 @@ mha_varlen_fwd( cu_seqlens_q, cu_seqlens_k, seqused_k, - block_table_, alibi_slopes_, max_seqlen_q, max_seqlen_k, @@ -367,7 +361,6 @@ inline std::tuple mha_bwd( const bool is_causal, int window_size_left, int window_size_right, - const float softcap, const bool deterministic, const at::Tensor philox_seed, const at::Tensor philox_offset) { @@ -419,7 +412,6 @@ inline std::tuple mha_varlen_bwd const bool is_causal, int window_size_left, int window_size_right, - const float softcap, const bool deterministic, const at::Tensor philox_seed, const at::Tensor philox_offset) { From 611f5b2c1cc43c68dab938049d90dd4d3d9f7e77 Mon Sep 17 00:00:00 2001 From: Xinya Zhang Date: Mon, 7 Jul 2025 17:08:26 -0500 Subject: [PATCH 18/27] fix "file INSTALL cannot make directory" when build with non-root users --- caffe2/CMakeLists.txt | 5 ----- 1 file changed, 5 deletions(-) diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 3cb4b81f81504..395b2695ff879 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -1915,9 +1915,4 @@ if(BUILD_PYTHON) add_custom_target(python_copy_files ALL DEPENDS ${build_files}) - - # Install commands - # Pick up static python files - install(DIRECTORY ${CMAKE_BINARY_DIR}/caffe2 DESTINATION ${PYTHON_LIB_REL_PATH} - FILES_MATCHING PATTERN "*.py") endif() From faa52356a83d0e1d8fe7630627651286f0c594d2 Mon Sep 17 00:00:00 2001 From: Xinya Zhang Date: Mon, 7 Jul 2025 17:13:40 -0500 Subject: [PATCH 19/27] add missing aotriton.images --- setup.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/setup.py b/setup.py index e263fffd5e1c8..12fb8b063ad6f 100644 --- a/setup.py +++ b/setup.py @@ -1396,6 +1396,13 @@ def main(): "lib/*.lib", ] ) + aotriton_image_path = os.path.join(lib_path, "aotriton.images") + aks2_files = [] + for root, dirs, files in os.walk(aotriton_image_path): + subpath = os.path.relpath(root, start=aotriton_image_path) + for fn in files: + aks2_files.append(os.path.join("lib/aotriton.images", subpath, fn)) + torch_package_data += aks2_files if get_cmake_cache_vars()["BUILD_CAFFE2"]: torch_package_data.extend( [ From 604c22e2d5e90002faecf51b968c919bcf61f3e3 Mon Sep 17 00:00:00 2001 From: Xinya Zhang Date: Mon, 7 Jul 2025 17:30:01 -0500 Subject: [PATCH 20/27] remove files newer than release/2.4 --- .ci/docker/libtorch/Dockerfile | 105 ------------- .ci/docker/manywheel/Dockerfile | 200 ------------------------ .ci/manywheel/build_rocm.sh | 268 -------------------------------- 3 files changed, 573 deletions(-) delete mode 100644 .ci/docker/libtorch/Dockerfile delete mode 100644 .ci/docker/manywheel/Dockerfile delete mode 100755 .ci/manywheel/build_rocm.sh diff --git a/.ci/docker/libtorch/Dockerfile b/.ci/docker/libtorch/Dockerfile deleted file mode 100644 index 8737d753c9405..0000000000000 --- a/.ci/docker/libtorch/Dockerfile +++ /dev/null @@ -1,105 +0,0 @@ -ARG BASE_TARGET=base -ARG GPU_IMAGE=ubuntu:20.04 -FROM ${GPU_IMAGE} as base - -ENV DEBIAN_FRONTEND=noninteractive - -RUN apt-get clean && apt-get update -RUN apt-get install -y curl locales g++ git-all autoconf automake make cmake wget unzip sudo -# Just add everything as a safe.directory for git since these will be used in multiple places with git -RUN git 
config --global --add safe.directory '*' - -RUN locale-gen en_US.UTF-8 - -ENV LC_ALL en_US.UTF-8 -ENV LANG en_US.UTF-8 -ENV LANGUAGE en_US.UTF-8 - -# Install openssl -FROM base as openssl -ADD ./common/install_openssl.sh install_openssl.sh -RUN bash ./install_openssl.sh && rm install_openssl.sh - -# Install python -FROM base as python -ADD common/install_cpython.sh install_cpython.sh -RUN apt-get update -y && \ - apt-get install build-essential gdb lcov libbz2-dev libffi-dev \ - libgdbm-dev liblzma-dev libncurses5-dev libreadline6-dev \ - libsqlite3-dev libssl-dev lzma lzma-dev tk-dev uuid-dev zlib1g-dev -y && \ - bash ./install_cpython.sh && \ - rm install_cpython.sh && \ - apt-get clean - -FROM base as conda -ADD ./common/install_conda_docker.sh install_conda.sh -RUN bash ./install_conda.sh && rm install_conda.sh - -FROM base as cpu -# Install Anaconda -COPY --from=conda /opt/conda /opt/conda -# Install python -COPY --from=python /opt/python /opt/python -COPY --from=python /opt/_internal /opt/_internal -ENV PATH=/opt/conda/bin:/usr/local/cuda/bin:$PATH -# Install MKL -ADD ./common/install_mkl.sh install_mkl.sh -RUN bash ./install_mkl.sh && rm install_mkl.sh - -FROM cpu as cuda -ADD ./common/install_cuda.sh install_cuda.sh -ADD ./common/install_magma.sh install_magma.sh -ENV CUDA_HOME /usr/local/cuda - -FROM cuda as cuda11.8 -RUN bash ./install_cuda.sh 11.8 -RUN bash ./install_magma.sh 11.8 -RUN ln -sf /usr/local/cuda-11.8 /usr/local/cuda - -FROM cuda as cuda12.1 -RUN bash ./install_cuda.sh 12.1 -RUN bash ./install_magma.sh 12.1 -RUN ln -sf /usr/local/cuda-12.1 /usr/local/cuda - -FROM cuda as cuda12.4 -RUN bash ./install_cuda.sh 12.4 -RUN bash ./install_magma.sh 12.4 -RUN ln -sf /usr/local/cuda-12.4 /usr/local/cuda - -FROM cuda as cuda12.6 -RUN bash ./install_cuda.sh 12.6 -RUN bash ./install_magma.sh 12.6 -RUN ln -sf /usr/local/cuda-12.6 /usr/local/cuda - -FROM cpu as rocm -ARG PYTORCH_ROCM_ARCH -ENV PYTORCH_ROCM_ARCH ${PYTORCH_ROCM_ARCH} -ENV MKLROOT /opt/intel -# Adding ROCM_PATH env var so that LoadHip.cmake (even with logic updated for ROCm6.0) -# find HIP works for ROCm5.7. Not needed for ROCm6.0 and above. -# Remove below when ROCm5.7 is not in support matrix anymore. 
-ENV ROCM_PATH /opt/rocm -# No need to install ROCm as base docker image should have full ROCm install -#ADD ./common/install_rocm.sh install_rocm.sh -ADD ./common/install_rocm_drm.sh install_rocm_drm.sh -ADD ./common/install_rocm_magma.sh install_rocm_magma.sh -# gfortran and python needed for building magma from source for ROCm -RUN apt-get update -y && \ - apt-get install gfortran -y && \ - apt-get install python -y && \ - apt-get clean - -RUN bash ./install_rocm_drm.sh && rm install_rocm_drm.sh -RUN bash ./install_rocm_magma.sh && rm install_rocm_magma.sh - -FROM ${BASE_TARGET} as final -COPY --from=openssl /opt/openssl /opt/openssl -# Install patchelf -ADD ./common/install_patchelf.sh install_patchelf.sh -RUN bash ./install_patchelf.sh && rm install_patchelf.sh -# Install Anaconda -COPY --from=conda /opt/conda /opt/conda -# Install python -COPY --from=python /opt/python /opt/python -COPY --from=python /opt/_internal /opt/_internal -ENV PATH=/opt/conda/bin:/usr/local/cuda/bin:$PATH diff --git a/.ci/docker/manywheel/Dockerfile b/.ci/docker/manywheel/Dockerfile deleted file mode 100644 index 04298fd0ed023..0000000000000 --- a/.ci/docker/manywheel/Dockerfile +++ /dev/null @@ -1,200 +0,0 @@ -# syntax = docker/dockerfile:experimental -ARG ROCM_VERSION=3.7 -ARG BASE_CUDA_VERSION=11.8 - -ARG GPU_IMAGE=centos:7 -FROM centos:7 as base - -ENV LC_ALL en_US.UTF-8 -ENV LANG en_US.UTF-8 -ENV LANGUAGE en_US.UTF-8 - -ARG DEVTOOLSET_VERSION=9 - -# Note: This is required patch since CentOS have reached EOL -# otherwise any yum install setp will fail -RUN sed -i s/mirror.centos.org/vault.centos.org/g /etc/yum.repos.d/*.repo -RUN sed -i s/^#.*baseurl=http/baseurl=http/g /etc/yum.repos.d/*.repo -RUN sed -i s/^mirrorlist=http/#mirrorlist=http/g /etc/yum.repos.d/*.repo -RUN yum install -y wget curl perl util-linux xz bzip2 git patch which perl zlib-devel -# Just add everything as a safe.directory for git since these will be used in multiple places with git -RUN git config --global --add safe.directory '*' -RUN yum install -y yum-utils centos-release-scl -RUN yum-config-manager --enable rhel-server-rhscl-7-rpms -# Note: After running yum-config-manager --enable rhel-server-rhscl-7-rpms -# patch is required once again. Somehow this steps adds mirror.centos.org -RUN sed -i s/mirror.centos.org/vault.centos.org/g /etc/yum.repos.d/*.repo -RUN sed -i s/^#.*baseurl=http/baseurl=http/g /etc/yum.repos.d/*.repo -RUN sed -i s/^mirrorlist=http/#mirrorlist=http/g /etc/yum.repos.d/*.repo -RUN yum install -y devtoolset-${DEVTOOLSET_VERSION}-gcc devtoolset-${DEVTOOLSET_VERSION}-gcc-c++ devtoolset-${DEVTOOLSET_VERSION}-gcc-gfortran devtoolset-${DEVTOOLSET_VERSION}-binutils -ENV PATH=/opt/rh/devtoolset-${DEVTOOLSET_VERSION}/root/usr/bin:$PATH -ENV LD_LIBRARY_PATH=/opt/rh/devtoolset-${DEVTOOLSET_VERSION}/root/usr/lib64:/opt/rh/devtoolset-${DEVTOOLSET_VERSION}/root/usr/lib:$LD_LIBRARY_PATH - -RUN yum --enablerepo=extras install -y epel-release - -# cmake-3.18.4 from pip -RUN yum install -y python3-pip && \ - python3 -mpip install cmake==3.18.4 && \ - ln -s /usr/local/bin/cmake /usr/bin/cmake - -RUN yum install -y autoconf aclocal automake make sudo - -FROM base as openssl -# Install openssl (this must precede `build python` step) -# (In order to have a proper SSL module, Python is compiled -# against a recent openssl [see env vars above], which is linked -# statically. We delete openssl afterwards.) 
-ADD ./common/install_openssl.sh install_openssl.sh -RUN bash ./install_openssl.sh && rm install_openssl.sh - -# EPEL for cmake -FROM base as patchelf -# Install patchelf -ADD ./common/install_patchelf.sh install_patchelf.sh -RUN bash ./install_patchelf.sh && rm install_patchelf.sh -RUN cp $(which patchelf) /patchelf - -FROM patchelf as python -# build python -COPY manywheel/build_scripts /build_scripts -ADD ./common/install_cpython.sh /build_scripts/install_cpython.sh -RUN bash build_scripts/build.sh && rm -r build_scripts - -FROM base as cuda -ARG BASE_CUDA_VERSION=10.2 -# Install CUDA -ADD ./common/install_cuda.sh install_cuda.sh -RUN bash ./install_cuda.sh ${BASE_CUDA_VERSION} && rm install_cuda.sh - -FROM base as intel -# MKL -ADD ./common/install_mkl.sh install_mkl.sh -RUN bash ./install_mkl.sh && rm install_mkl.sh - -FROM base as magma -ARG BASE_CUDA_VERSION=10.2 -# Install magma -ADD ./common/install_magma.sh install_magma.sh -RUN bash ./install_magma.sh ${BASE_CUDA_VERSION} && rm install_magma.sh - -FROM base as jni -# Install java jni header -ADD ./common/install_jni.sh install_jni.sh -ADD ./java/jni.h jni.h -RUN bash ./install_jni.sh && rm install_jni.sh - -FROM base as libpng -# Install libpng -ADD ./common/install_libpng.sh install_libpng.sh -RUN bash ./install_libpng.sh && rm install_libpng.sh - -FROM ${GPU_IMAGE} as common -RUN sed -i s/mirror.centos.org/vault.centos.org/g /etc/yum.repos.d/*.repo -RUN sed -i s/^#.*baseurl=http/baseurl=http/g /etc/yum.repos.d/*.repo -RUN sed -i s/^mirrorlist=http/#mirrorlist=http/g /etc/yum.repos.d/*.repo -ENV LC_ALL en_US.UTF-8 -ENV LANG en_US.UTF-8 -ENV LANGUAGE en_US.UTF-8 -RUN yum install -y \ - aclocal \ - autoconf \ - automake \ - bison \ - bzip2 \ - curl \ - diffutils \ - file \ - git \ - make \ - patch \ - perl \ - unzip \ - util-linux \ - wget \ - which \ - xz \ - yasm -RUN yum install -y \ - https://repo.ius.io/ius-release-el7.rpm \ - https://ossci-linux.s3.amazonaws.com/epel-release-7-14.noarch.rpm - -RUN yum swap -y git git236-core -# git236+ would refuse to run git commands in repos owned by other users -# Which causes version check to fail, as pytorch repo is bind-mounted into the image -# Override this behaviour by treating every folder as safe -# For more details see https://github.com/pytorch/pytorch/issues/78659#issuecomment-1144107327 -RUN git config --global --add safe.directory "*" - -ENV SSL_CERT_FILE=/opt/_internal/certs.pem -# Install LLVM version -COPY --from=openssl /opt/openssl /opt/openssl -COPY --from=python /opt/python /opt/python -COPY --from=python /opt/_internal /opt/_internal -COPY --from=python /opt/python/cp39-cp39/bin/auditwheel /usr/local/bin/auditwheel -COPY --from=intel /opt/intel /opt/intel -COPY --from=patchelf /usr/local/bin/patchelf /usr/local/bin/patchelf -COPY --from=jni /usr/local/include/jni.h /usr/local/include/jni.h -COPY --from=libpng /usr/local/bin/png* /usr/local/bin/ -COPY --from=libpng /usr/local/bin/libpng* /usr/local/bin/ -COPY --from=libpng /usr/local/include/png* /usr/local/include/ -COPY --from=libpng /usr/local/include/libpng* /usr/local/include/ -COPY --from=libpng /usr/local/lib/libpng* /usr/local/lib/ -COPY --from=libpng /usr/local/lib/pkgconfig /usr/local/lib/pkgconfig - -FROM common as cpu_final -ARG BASE_CUDA_VERSION=10.1 -ARG DEVTOOLSET_VERSION=9 -# Install Anaconda -ADD ./common/install_conda_docker.sh install_conda.sh -RUN bash ./install_conda.sh && rm install_conda.sh -ENV PATH /opt/conda/bin:$PATH -RUN sed -i s/mirror.centos.org/vault.centos.org/g /etc/yum.repos.d/*.repo 
-RUN sed -i s/^#.*baseurl=http/baseurl=http/g /etc/yum.repos.d/*.repo -RUN sed -i s/^mirrorlist=http/#mirrorlist=http/g /etc/yum.repos.d/*.repo - -RUN yum install -y yum-utils centos-release-scl -RUN yum-config-manager --enable rhel-server-rhscl-7-rpms -RUN sed -i s/mirror.centos.org/vault.centos.org/g /etc/yum.repos.d/*.repo -RUN sed -i s/^#.*baseurl=http/baseurl=http/g /etc/yum.repos.d/*.repo -RUN sed -i s/^mirrorlist=http/#mirrorlist=http/g /etc/yum.repos.d/*.repo -RUN yum install -y devtoolset-${DEVTOOLSET_VERSION}-gcc devtoolset-${DEVTOOLSET_VERSION}-gcc-c++ devtoolset-${DEVTOOLSET_VERSION}-gcc-gfortran devtoolset-${DEVTOOLSET_VERSION}-binutils -ENV PATH=/opt/rh/devtoolset-${DEVTOOLSET_VERSION}/root/usr/bin:$PATH -ENV LD_LIBRARY_PATH=/opt/rh/devtoolset-${DEVTOOLSET_VERSION}/root/usr/lib64:/opt/rh/devtoolset-${DEVTOOLSET_VERSION}/root/usr/lib:$LD_LIBRARY_PATH - -# cmake is already installed inside the rocm base image, so remove if present -RUN rpm -e cmake || true -# cmake-3.18.4 from pip -RUN yum install -y python3-pip && \ - python3 -mpip install cmake==3.18.4 && \ - ln -s /usr/local/bin/cmake /usr/bin/cmake - -# ninja -RUN yum install -y ninja-build - -FROM cpu_final as cuda_final -RUN rm -rf /usr/local/cuda-${BASE_CUDA_VERSION} -COPY --from=cuda /usr/local/cuda-${BASE_CUDA_VERSION} /usr/local/cuda-${BASE_CUDA_VERSION} -COPY --from=magma /usr/local/cuda-${BASE_CUDA_VERSION} /usr/local/cuda-${BASE_CUDA_VERSION} -RUN ln -sf /usr/local/cuda-${BASE_CUDA_VERSION} /usr/local/cuda -ENV PATH=/usr/local/cuda/bin:$PATH - -FROM cpu_final as rocm_final -ARG ROCM_VERSION=3.7 -ARG PYTORCH_ROCM_ARCH -ENV PYTORCH_ROCM_ARCH ${PYTORCH_ROCM_ARCH} -# Adding ROCM_PATH env var so that LoadHip.cmake (even with logic updated for ROCm6.0) -# find HIP works for ROCm5.7. Not needed for ROCm6.0 and above. -# Remove below when ROCm5.7 is not in support matrix anymore. -ENV ROCM_PATH /opt/rocm -ENV MKLROOT /opt/intel -# No need to install ROCm as base docker image should have full ROCm install -#ADD ./common/install_rocm.sh install_rocm.sh -#RUN ROCM_VERSION=${ROCM_VERSION} bash ./install_rocm.sh && rm install_rocm.sh -ADD ./common/install_rocm_drm.sh install_rocm_drm.sh -RUN bash ./install_rocm_drm.sh && rm install_rocm_drm.sh -# cmake3 is needed for the MIOpen build -RUN ln -sf /usr/local/bin/cmake /usr/bin/cmake3 -ADD ./common/install_rocm_magma.sh install_rocm_magma.sh -RUN bash ./install_rocm_magma.sh && rm install_rocm_magma.sh -ADD ./common/install_miopen.sh install_miopen.sh -RUN bash ./install_miopen.sh ${ROCM_VERSION} && rm install_miopen.sh diff --git a/.ci/manywheel/build_rocm.sh b/.ci/manywheel/build_rocm.sh deleted file mode 100755 index 703248d44aa91..0000000000000 --- a/.ci/manywheel/build_rocm.sh +++ /dev/null @@ -1,268 +0,0 @@ -#!/usr/bin/env bash - -set -ex - -export ROCM_HOME=/opt/rocm -export MAGMA_HOME=$ROCM_HOME/magma -# TODO: libtorch_cpu.so is broken when building with Debug info -export BUILD_DEBUG_INFO=0 - -# TODO Are these all used/needed? 
-export TH_BINARY_BUILD=1
-export USE_STATIC_CUDNN=1
-export USE_STATIC_NCCL=1
-export ATEN_STATIC_CUDA=1
-export USE_CUDA_STATIC_LINK=1
-export INSTALL_TEST=0 # dont install test binaries into site-packages
-# Set RPATH instead of RUNPATH when using patchelf to avoid LD_LIBRARY_PATH override
-export FORCE_RPATH="--force-rpath"
-
-# Keep an array of cmake variables to add to
-if [[ -z "$CMAKE_ARGS" ]]; then
-    # These are passed to tools/build_pytorch_libs.sh::build()
-    CMAKE_ARGS=()
-fi
-if [[ -z "$EXTRA_CAFFE2_CMAKE_FLAGS" ]]; then
-    # These are passed to tools/build_pytorch_libs.sh::build_caffe2()
-    EXTRA_CAFFE2_CMAKE_FLAGS=()
-fi
-
-# Determine ROCm version and architectures to build for
-#
-# NOTE: We should first check `DESIRED_CUDA` when determining `ROCM_VERSION`
-if [[ -n "$DESIRED_CUDA" ]]; then
-    if ! echo "${DESIRED_CUDA}"| grep "^rocm" >/dev/null 2>/dev/null; then
-        export DESIRED_CUDA="rocm${DESIRED_CUDA}"
-    fi
-    # rocm3.7, rocm3.5.1
-    ROCM_VERSION="$DESIRED_CUDA"
-    echo "Using $ROCM_VERSION as determined by DESIRED_CUDA"
-else
-    echo "Must set DESIRED_CUDA"
-    exit 1
-fi
-
-# Package directories
-WHEELHOUSE_DIR="wheelhouse$ROCM_VERSION"
-LIBTORCH_HOUSE_DIR="libtorch_house$ROCM_VERSION"
-if [[ -z "$PYTORCH_FINAL_PACKAGE_DIR" ]]; then
-    if [[ -z "$BUILD_PYTHONLESS" ]]; then
-        PYTORCH_FINAL_PACKAGE_DIR="/remote/wheelhouse$ROCM_VERSION"
-    else
-        PYTORCH_FINAL_PACKAGE_DIR="/remote/libtorch_house$ROCM_VERSION"
-    fi
-fi
-mkdir -p "$PYTORCH_FINAL_PACKAGE_DIR" || true
-
-# To make version comparison easier, create an integer representation.
-ROCM_VERSION_CLEAN=$(echo ${ROCM_VERSION} | sed s/rocm//)
-save_IFS="$IFS"
-IFS=. ROCM_VERSION_ARRAY=(${ROCM_VERSION_CLEAN})
-IFS="$save_IFS"
-if [[ ${#ROCM_VERSION_ARRAY[@]} == 2 ]]; then
-    ROCM_VERSION_MAJOR=${ROCM_VERSION_ARRAY[0]}
-    ROCM_VERSION_MINOR=${ROCM_VERSION_ARRAY[1]}
-    ROCM_VERSION_PATCH=0
-elif [[ ${#ROCM_VERSION_ARRAY[@]} == 3 ]]; then
-    ROCM_VERSION_MAJOR=${ROCM_VERSION_ARRAY[0]}
-    ROCM_VERSION_MINOR=${ROCM_VERSION_ARRAY[1]}
-    ROCM_VERSION_PATCH=${ROCM_VERSION_ARRAY[2]}
-else
-    echo "Unhandled ROCM_VERSION ${ROCM_VERSION}"
-    exit 1
-fi
-ROCM_INT=$(($ROCM_VERSION_MAJOR * 10000 + $ROCM_VERSION_MINOR * 100 + $ROCM_VERSION_PATCH))
-
-# Required ROCm libraries
-ROCM_SO_FILES=(
-    "libMIOpen.so"
-    "libamdhip64.so"
-    "libhipblas.so"
-    "libhipfft.so"
-    "libhiprand.so"
-    "libhipsolver.so"
-    "libhipsparse.so"
-    "libhsa-runtime64.so"
-    "libamd_comgr.so"
-    "libmagma.so"
-    "librccl.so"
-    "librocblas.so"
-    "librocfft.so"
-    "librocm_smi64.so"
-    "librocrand.so"
-    "librocsolver.so"
-    "librocsparse.so"
-    "libroctracer64.so"
-    "libroctx64.so"
-    "libhipblaslt.so"
-    "libhiprtc.so"
-)
-
-if [[ $ROCM_INT -ge 60100 ]]; then
-    ROCM_SO_FILES+=("librocprofiler-register.so")
-fi
-
-if [[ $ROCM_INT -ge 60200 ]]; then
-    ROCM_SO_FILES+=("librocm-core.so")
-fi
-
-OS_NAME=`awk -F= '/^NAME/{print $2}' /etc/os-release`
-if [[ "$OS_NAME" == *"CentOS Linux"* || "$OS_NAME" == *"AlmaLinux"* ]]; then
-    LIBGOMP_PATH="/usr/lib64/libgomp.so.1"
-    LIBNUMA_PATH="/usr/lib64/libnuma.so.1"
-    LIBELF_PATH="/usr/lib64/libelf.so.1"
-    if [[ "$OS_NAME" == *"CentOS Linux"* ]]; then
-        LIBTINFO_PATH="/usr/lib64/libtinfo.so.5"
-    else
-        LIBTINFO_PATH="/usr/lib64/libtinfo.so.6"
-    fi
-    LIBDRM_PATH="/opt/amdgpu/lib64/libdrm.so.2"
-    LIBDRM_AMDGPU_PATH="/opt/amdgpu/lib64/libdrm_amdgpu.so.1"
-    if [[ $ROCM_INT -ge 60100 && $ROCM_INT -lt 60300 ]]; then
-        # Below libs are direct dependencies of libhipsolver
-        LIBSUITESPARSE_CONFIG_PATH="/lib64/libsuitesparseconfig.so.4"
-        if [[ "$OS_NAME" == *"CentOS Linux"* ]]; then
-            LIBCHOLMOD_PATH="/lib64/libcholmod.so.2"
-            # Below libs are direct dependencies of libsatlas
-            LIBGFORTRAN_PATH="/lib64/libgfortran.so.3"
-        else
-            LIBCHOLMOD_PATH="/lib64/libcholmod.so.3"
-            # Below libs are direct dependencies of libsatlas
-            LIBGFORTRAN_PATH="/lib64/libgfortran.so.5"
-        fi
-        # Below libs are direct dependencies of libcholmod
-        LIBAMD_PATH="/lib64/libamd.so.2"
-        LIBCAMD_PATH="/lib64/libcamd.so.2"
-        LIBCCOLAMD_PATH="/lib64/libccolamd.so.2"
-        LIBCOLAMD_PATH="/lib64/libcolamd.so.2"
-        LIBSATLAS_PATH="/lib64/atlas/libsatlas.so.3"
-        # Below libs are direct dependencies of libsatlas
-        LIBQUADMATH_PATH="/lib64/libquadmath.so.0"
-    fi
-    MAYBE_LIB64=lib64
-elif [[ "$OS_NAME" == *"Ubuntu"* ]]; then
-    LIBGOMP_PATH="/usr/lib/x86_64-linux-gnu/libgomp.so.1"
-    LIBNUMA_PATH="/usr/lib/x86_64-linux-gnu/libnuma.so.1"
-    LIBELF_PATH="/usr/lib/x86_64-linux-gnu/libelf.so.1"
-    if [[ $ROCM_INT -ge 50300 ]]; then
-        LIBTINFO_PATH="/lib/x86_64-linux-gnu/libtinfo.so.6"
-    else
-        LIBTINFO_PATH="/lib/x86_64-linux-gnu/libtinfo.so.5"
-    fi
-    LIBDRM_PATH="/usr/lib/x86_64-linux-gnu/libdrm.so.2"
-    LIBDRM_AMDGPU_PATH="/usr/lib/x86_64-linux-gnu/libdrm_amdgpu.so.1"
-    if [[ $ROCM_INT -ge 60100 && $ROCM_INT -lt 60300 ]]; then
-        # Below libs are direct dependencies of libhipsolver
-        LIBCHOLMOD_PATH="/lib/x86_64-linux-gnu/libcholmod.so.3"
-        # Below libs are direct dependencies of libcholmod
-        LIBSUITESPARSE_CONFIG_PATH="/lib/x86_64-linux-gnu/libsuitesparseconfig.so.5"
-        LIBAMD_PATH="/lib/x86_64-linux-gnu/libamd.so.2"
-        LIBCAMD_PATH="/lib/x86_64-linux-gnu/libcamd.so.2"
-        LIBCCOLAMD_PATH="/lib/x86_64-linux-gnu/libccolamd.so.2"
-        LIBCOLAMD_PATH="/lib/x86_64-linux-gnu/libcolamd.so.2"
-        LIBMETIS_PATH="/lib/x86_64-linux-gnu/libmetis.so.5"
-        LIBLAPACK_PATH="/lib/x86_64-linux-gnu/liblapack.so.3"
-        LIBBLAS_PATH="/lib/x86_64-linux-gnu/libblas.so.3"
-        # Below libs are direct dependencies of libblas
-        LIBGFORTRAN_PATH="/lib/x86_64-linux-gnu/libgfortran.so.5"
-        LIBQUADMATH_PATH="/lib/x86_64-linux-gnu/libquadmath.so.0"
-    fi
-    MAYBE_LIB64=lib
-fi
-OS_SO_PATHS=($LIBGOMP_PATH $LIBNUMA_PATH\
-             $LIBELF_PATH $LIBTINFO_PATH\
-             $LIBDRM_PATH $LIBDRM_AMDGPU_PATH\
-             $LIBSUITESPARSE_CONFIG_PATH\
-             $LIBCHOLMOD_PATH $LIBAMD_PATH\
-             $LIBCAMD_PATH $LIBCCOLAMD_PATH\
-             $LIBCOLAMD_PATH $LIBSATLAS_PATH\
-             $LIBGFORTRAN_PATH $LIBQUADMATH_PATH\
-             $LIBMETIS_PATH $LIBLAPACK_PATH\
-             $LIBBLAS_PATH)
-OS_SO_FILES=()
-for lib in "${OS_SO_PATHS[@]}"
-do
-    file_name="${lib##*/}" # Substring removal of path to get filename
-    OS_SO_FILES[${#OS_SO_FILES[@]}]=$file_name # Append lib to array
-done
-
-# rocBLAS library files
-ROCBLAS_LIB_SRC=$ROCM_HOME/lib/rocblas/library
-ROCBLAS_LIB_DST=lib/rocblas/library
-ARCH=$(echo $PYTORCH_ROCM_ARCH | sed 's/;/|/g') # Replace ; seperated arch list to bar for grep
-ARCH_SPECIFIC_FILES=$(ls $ROCBLAS_LIB_SRC | grep -E $ARCH)
-OTHER_FILES=$(ls $ROCBLAS_LIB_SRC | grep -v gfx)
-ROCBLAS_LIB_FILES=($ARCH_SPECIFIC_FILES $OTHER_FILES)
-
-# hipblaslt library files
-HIPBLASLT_LIB_SRC=$ROCM_HOME/lib/hipblaslt/library
-HIPBLASLT_LIB_DST=lib/hipblaslt/library
-ARCH_SPECIFIC_FILES=$(ls $HIPBLASLT_LIB_SRC | grep -E $ARCH)
-OTHER_FILES=$(ls $HIPBLASLT_LIB_SRC | grep -v gfx)
-HIPBLASLT_LIB_FILES=($ARCH_SPECIFIC_FILES $OTHER_FILES)
-
-# ROCm library files
-ROCM_SO_PATHS=()
-for lib in "${ROCM_SO_FILES[@]}"
-do
-    file_path=($(find $ROCM_HOME/lib/ -name "$lib")) # First search in lib
-    if [[ -z $file_path ]]; then
-        if [ -d "$ROCM_HOME/lib64/" ]; then
-            file_path=($(find $ROCM_HOME/lib64/ -name "$lib")) # Then search in lib64
-        fi
-    fi
-    if [[ -z $file_path ]]; then
-        file_path=($(find $ROCM_HOME/ -name "$lib")) # Then search in ROCM_HOME
-    fi
-    if [[ -z $file_path ]]; then
-        echo "Error: Library file $lib is not found." >&2
-        exit 1
-    fi
-    ROCM_SO_PATHS[${#ROCM_SO_PATHS[@]}]="$file_path" # Append lib to array
-done
-
-DEPS_LIST=(
-    ${ROCM_SO_PATHS[*]}
-    ${OS_SO_PATHS[*]}
-)
-
-DEPS_SONAME=(
-    ${ROCM_SO_FILES[*]}
-    ${OS_SO_FILES[*]}
-)
-
-DEPS_AUX_SRCLIST=(
-    "${ROCBLAS_LIB_FILES[@]/#/$ROCBLAS_LIB_SRC/}"
-    "${HIPBLASLT_LIB_FILES[@]/#/$HIPBLASLT_LIB_SRC/}"
-    "/opt/amdgpu/share/libdrm/amdgpu.ids"
-)
-
-DEPS_AUX_DSTLIST=(
-    "${ROCBLAS_LIB_FILES[@]/#/$ROCBLAS_LIB_DST/}"
-    "${HIPBLASLT_LIB_FILES[@]/#/$HIPBLASLT_LIB_DST/}"
-    "share/libdrm/amdgpu.ids"
-)
-
-# MIOpen library files
-MIOPEN_SHARE_SRC=$ROCM_HOME/share/miopen/db
-MIOPEN_SHARE_DST=share/miopen/db
-MIOPEN_SHARE_FILES=($(ls $MIOPEN_SHARE_SRC | grep -E $ARCH))
-DEPS_AUX_SRCLIST+=(${MIOPEN_SHARE_FILES[@]/#/$MIOPEN_SHARE_SRC/})
-DEPS_AUX_DSTLIST+=(${MIOPEN_SHARE_FILES[@]/#/$MIOPEN_SHARE_DST/})
-
-# RCCL library files
-RCCL_SHARE_SRC=$ROCM_HOME/share/rccl/msccl-algorithms
-RCCL_SHARE_DST=share/rccl/msccl-algorithms
-RCCL_SHARE_FILES=($(ls $RCCL_SHARE_SRC))
-DEPS_AUX_SRCLIST+=(${RCCL_SHARE_FILES[@]/#/$RCCL_SHARE_SRC/})
-DEPS_AUX_DSTLIST+=(${RCCL_SHARE_FILES[@]/#/$RCCL_SHARE_DST/})
-
-echo "PYTORCH_ROCM_ARCH: ${PYTORCH_ROCM_ARCH}"
-
-SCRIPTPATH="$( cd "$(dirname "$0")" ; pwd -P )"
-if [[ -z "$BUILD_PYTHONLESS" ]]; then
-    BUILD_SCRIPT=build_common.sh
-else
-    BUILD_SCRIPT=build_libtorch.sh
-fi
-source $SCRIPTPATH/${BUILD_SCRIPT}
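Note: the deleted build_rocm.sh gates its library lists on an integer encoding of the ROCm release (major*10000 + minor*100 + patch, so rocm6.2 becomes ROCM_INT=60200). A minimal Python sketch of the same encoding, with an illustrative helper name:

    def rocm_int(version: str) -> int:
        # "rocm6.2.4" -> 60204; a missing patch level counts as 0 ("rocm6.2" -> 60200).
        major, minor, *rest = version.removeprefix("rocm").split(".")
        patch = int(rest[0]) if rest else 0
        return int(major) * 10000 + int(minor) * 100 + patch

    assert rocm_int("rocm6.2") == 60200    # threshold at which librocm-core.so is bundled
    assert rocm_int("rocm6.1.3") >= 60100  # threshold for librocprofiler-register.so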
From 83f9fcac2b457dc32c50fc01f506029820f45124 Mon Sep 17 00:00:00 2001
From: Xinya Zhang
Date: Mon, 7 Jul 2025 17:33:13 -0500
Subject: [PATCH 21/27] enable UT for arch supported by AOTriton 0.9.x

---
 torch/testing/_internal/common_cuda.py | 30 +++++++++++++++-----------
 1 file changed, 17 insertions(+), 13 deletions(-)

diff --git a/torch/testing/_internal/common_cuda.py b/torch/testing/_internal/common_cuda.py
index e93042e21929d..9d6c41949dcec 100644
--- a/torch/testing/_internal/common_cuda.py
+++ b/torch/testing/_internal/common_cuda.py
@@ -33,32 +33,36 @@
 IS_JETSON = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() in [(7, 2), (8, 7)])

-def CDNA2OrLater():
-    if TEST_WITH_ROCM:
-        gcn_arch_name = torch.cuda.get_device_properties('cuda').gcnArchName
-        return any(arch in gcn_arch_name for arch in {"gfx90a", "gfx940", "gfx941", "gfx942"})
-    return False
-
-def evaluate_gfx_arch_exact(matching_arch):
+def evaluate_gfx_arch_within(arch_list):
     if not torch.cuda.is_available():
         return False
     gcn_arch_name = torch.cuda.get_device_properties('cuda').gcnArchName
-    arch = os.environ.get('PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE', gcn_arch_name)
-    return arch == matching_arch
+    effective_arch = os.environ.get('PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE', gcn_arch_name)
+    # gcnArchName can be complicated strings like gfx90a:sramecc+:xnack-
+    # Hence the matching should be done reversely
+    return any(arch in effective_arch for arch in arch_list)

-GFX90A_Exact = LazyVal(lambda: evaluate_gfx_arch_exact('gfx90a:sramecc+:xnack-'))
-GFX942_Exact = LazyVal(lambda: evaluate_gfx_arch_exact('gfx942:sramecc+:xnack-'))
+def CDNA2OrLater():
+    return evaluate_gfx_arch_within(["gfx90a", "gfx942"])

 def evaluate_platform_supports_flash_attention():
     if TEST_WITH_ROCM:
-        return evaluate_gfx_arch_exact('gfx90a:sramecc+:xnack-') or evaluate_gfx_arch_exact('gfx942:sramecc+:xnack-')
+        arch_list = ["gfx90a", "gfx942", "gfx1100"]
+        version = _get_torch_rocm_version()
+        if version >= (6, 5):
+            arch_list += ["gfx950"]
+        return evaluate_gfx_arch_within(arch_list)
     if TEST_CUDA:
         return not IS_WINDOWS and SM80OrLater
     return False

 def evaluate_platform_supports_efficient_attention():
     if TEST_WITH_ROCM:
-        return evaluate_gfx_arch_exact('gfx90a:sramecc+:xnack-') or evaluate_gfx_arch_exact('gfx942:sramecc+:xnack-')
+        arch_list = ["gfx90a", "gfx942", "gfx1100"]
+        version = _get_torch_rocm_version()
+        if version >= (6, 5):
+            arch_list += ["gfx950"]
+        return evaluate_gfx_arch_within(arch_list)
     if TEST_CUDA:
         return True
     return False
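The behavioral point of this patch: ROCm architecture checks in common_cuda.py now match by substring against the device's gcnArchName (or the PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE override) instead of requiring an exact string such as 'gfx90a:sramecc+:xnack-'. A minimal sketch of that comparison; the sample gcnArchName values are illustrative:

    def arch_within(effective_arch: str, arch_list: list[str]) -> bool:
        # Same containment test the patch adds: any listed arch may appear as a
        # substring of the full gcnArchName reported by the device.
        return any(arch in effective_arch for arch in arch_list)

    assert arch_within("gfx90a:sramecc+:xnack-", ["gfx90a", "gfx942"])
    assert not arch_within("gfx1030", ["gfx90a", "gfx942", "gfx1100"])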
From ff5e4b1ccd56ff112e41288072e8d4a13d0b406b Mon Sep 17 00:00:00 2001
From: Xinya Zhang
Date: Tue, 8 Jul 2025 02:27:11 -0500
Subject: [PATCH 22/27] USE_ROCM_ATTENTION -> USE_AOTRITON

---
 aten/src/ATen/native/transformers/cuda/sdp_utils.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
index 45e202b8c0e3f..34f1daa27b9ce 100644
--- a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
+++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
@@ -88,7 +88,7 @@ int64_t minimum_gemm_alignment(sdp_params const& params) {
 }

 bool check_head_dim_size_flash(sdp_params const& params, bool debug) {
-#if USE_ROCM_ATTENTION && AOTRITON_VERSION_MINOR >= 9
+#if USE_AOTRITON && AOTRITON_VERSION_MINOR >= 9
   // AOTriton 0.9+ supports head_dim up to 512
   const auto max_size = c10::SymInt(512);
 #else
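The gate above raises the ROCm flash-attention head-dimension cap to 512 whenever the build uses AOTriton 0.9 or newer, which is also why the test change in patch 25 below moves its deliberately invalid size from 257 to 513. A small Python sketch of the bound only; the real check_head_dim_size_flash() enforces more than this single comparison:

    MAX_HEAD_DIM = 512  # cap cited in the patch for AOTriton 0.9+

    def head_dim_ok(head_dim: int) -> bool:
        # Sketch of the size bound alone, not the full SDPA eligibility check.
        return head_dim <= MAX_HEAD_DIM

    assert head_dim_ok(257)      # rejected by the old ROCm limit, allowed under 0.9+
    assert not head_dim_ok(513)  # the probe size the updated test uses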
From 1d3b6c36504abd582aaca97dd70f725b8084eb67 Mon Sep 17 00:00:00 2001
From: Xinya Zhang
Date: Tue, 8 Jul 2025 02:29:59 -0500
Subject: [PATCH 23/27] flash_api: remove _ck functions

---
 .../transformers/hip/flash_attn/flash_api.h | 122 ------------------
 1 file changed, 122 deletions(-)

diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h b/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h
index b0dc8d2576a48..cddd6dfb7a885 100644
--- a/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h
+++ b/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h
@@ -124,128 +124,6 @@ std::tuple mha_varlen_bwd_aot(
     const at::Tensor& philox_seed,
     const at::Tensor& philox_offset);

-#if defined(USE_CK_FLASH_ATTENTION)
-// CK implementation
-TORCH_API
-std::tuple<
-    at::Tensor,
-    at::Tensor,
-    at::Tensor,
-    at::Tensor,
-    at::Tensor,
-    at::Tensor,
-    at::Tensor,
-    at::Tensor>
-mha_fwd_ck(
-    const at::Tensor& q, // batch_size x seqlen_q x num_heads x head_size
-    const at::Tensor& k, // batch_size x seqlen_k x num_heads_k x head_size
-    const at::Tensor& v, // batch_size x seqlen_k x num_heads_k x head_size
-    std::optional&
-        out_, // batch_size x seqlen_q x num_heads x head_size
-    const float p_dropout,
-    const float softmax_scale,
-    bool is_causal,
-    int window_size_left,
-    int window_size_right,
-    const bool return_softmax,
-    std::optional gen_,
-    const std::optional& attn_bias_); // batch_size x nheads x seqlen_q x seqlen_k
-
-std::tuple<
-    at::Tensor,
-    at::Tensor,
-    at::Tensor,
-    at::Tensor,
-    at::Tensor,
-    at::Tensor,
-    at::Tensor,
-    at::Tensor>
-mha_varlen_fwd_ck(
-    const at::Tensor&
-        q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
-    const at::Tensor&
-        k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
-    const at::Tensor&
-        v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
-    std::optional&
-        out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
-    const at::Tensor& cu_seqlens_q, // b+1
-    const at::Tensor& cu_seqlens_k, // b+1
-    std::optional&
-        seqused_k, // b. If given, only this many elements of each batch
-                   // element's keys are used.
-    int max_seqlen_q,
-    const int max_seqlen_k,
-    const float p_dropout,
-    const float softmax_scale,
-    const bool zero_tensors,
-    bool is_causal,
-    int window_size_left,
-    int window_size_right,
-    const bool return_softmax,
-    std::optional gen_,
-    const std::optional& attn_bias_);
-
-std::tuple mha_bwd_ck(
-    const at::Tensor& dout, // batch_size x seqlen_q x num_heads, x head_size_og
-    const at::Tensor& q, // batch_size x seqlen_q x num_heads x head_size
-    const at::Tensor& k, // batch_size x seqlen_k x num_heads_k x head_size
-    const at::Tensor& v, // batch_size x seqlen_k x num_heads_k x head_size
-    const at::Tensor& out, // batch_size x seqlen_q x num_heads x head_size
-    const at::Tensor& softmax_lse, // b x h x seqlen_q
-    std::optional&
-        dq_, // batch_size x seqlen_q x num_heads x head_size
-    std::optional&
-        dk_, // batch_size x seqlen_k x num_heads_k x head_size
-    std::optional&
-        dv_, // batch_size x seqlen_k x num_heads_k x head_size
-    std::optional&
-        attn_bias_, // batch_size x num_heads x seqlen_q x seqlen_k
-    bool bias_requires_grad,
-    std::optional& grad_bias,
-    const float p_dropout, // probability to drop
-    const float softmax_scale,
-    const bool is_causal,
-    int window_size_left,
-    int window_size_right,
-    const bool deterministic,
-    const at::Tensor philox_seed,
-    const at::Tensor philox_offset);
-
-std::tuple mha_varlen_bwd_ck(
-    const at::Tensor& dout, // total_q x num_heads, x head_size
-    const at::Tensor&
-        q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
-    const at::Tensor&
-        k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
-    const at::Tensor&
-        v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
-    const at::Tensor& out, // total_q x num_heads x head_size
-    const at::Tensor& softmax_lse, // b x h x s softmax logsumexp
-    std::optional&
-        dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
-    std::optional&
-        dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
-    std::optional&
-        dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
-    const at::Tensor& cu_seqlens_q, // b+1
-    const at::Tensor& cu_seqlens_k, // b+1
-    std::optional& attn_bias_, // num_heads or b x num_heads
-    bool bias_requires_grad,
-    std::optional& grad_bias,
-    const int max_seqlen_q,
-    const int max_seqlen_k, // max sequence length to choose the kernel
-    const float p_dropout, // probability to drop
-    const float softmax_scale,
-    const bool zero_tensors,
-    const bool is_causal,
-    int window_size_left,
-    int window_size_right,
-    const bool deterministic,
-    const at::Tensor philox_seed,
-    const at::Tensor philox_offset);
-#endif
-
 TORCH_API
 inline std::tuple<
     at::Tensor,

From 6aa4ae30fe4931d29122c037e6f3472e30441bb6 Mon Sep 17 00:00:00 2001
From: Xinya Zhang
Date: Tue, 15 Jul 2025 20:07:44 -0500
Subject: [PATCH 24/27] Use AOTriton 0.10b instead to pass all UTs

---
 cmake/External/aotriton.cmake | 24 ++++++++++--------------
 1 file changed, 10 insertions(+), 14 deletions(-)

diff --git a/cmake/External/aotriton.cmake b/cmake/External/aotriton.cmake
index df3102ff3587a..49415bf5ebc2a 100644
--- a/cmake/External/aotriton.cmake
+++ b/cmake/External/aotriton.cmake
@@ -22,35 +22,31 @@ if(NOT __AOTRITON_INCLUDED)
   # Replaces .ci/docker/aotriton_version.txt
   # Note packages information may have versions skipped (due to no ABI breaks)
   # But they must be listed from lower version to higher version
-  set(__AOTRITON_RELEASE_PAGE "0.9.2b")
+  set(__AOTRITON_RELEASE_PAGE "0.10b")
   set(__AOTRITON_VER_LIST
-      "0.9.2b" # rocm6.2
-      "0.9.2b" # rocm6.3
-      "0.9.2b" # rocm6.4
-      "0.9.2b_612896439f" # rocm6.5 with gfx950
-      "0.9.2b_612896439f" # rocm7.0
+      "0.10b" # rocm6.3
+      "0.10b" # rocm6.4
+      "0.10b" # rocm6.5
+      "0.10b" # rocm7.0
   )
   set(__AOTRITON_MANYLINUX_LIST
-      "manylinux_2_28" # rocm6.2
       "manylinux_2_28" # rocm6.3
       "manylinux_2_28" # rocm6.4
       "manylinux_2_28" # rocm6.5
       "manylinux_2_28" # rocm7.0
   )
   set(__AOTRITON_ROCM_LIST
-      "rocm6.2"
       "rocm6.3"
       "rocm6.4"
       "rocm6.5"
       "rocm7.0"
   )
-  set(__AOTRITON_CI_COMMIT "612896439fb4f78509b1a566b5ef0a333e9585bb") # source of rocm6.5 with gfx950
+  set(__AOTRITON_CI_COMMIT "6fca155f4deeb8d9529326f7b69f350aeeb93477") # source of rocm6.5 with gfx950
   set(__AOTRITON_SHA256_LIST
-      "08d84f96f4c984179f80f517c0431c7511ee26bb0ce9bd05a827573ddd78cc79" # rocm6.2
-      "9094d59717e7e6eace9126ca100dd0e86510f07fc6c3a349569fc4e2d9056604" # rocm6.3
-      "41190202c2736d5ff75b13a3abc0fb52ebfbb67226cf85dc3de7699c7000db44" # rocm6.4
-      "c85da64d21510190277794455ef8bd3f2d543a6f2462140d3da27e1df0ab8f82" # rocm6.5 with gfx950
-      "9061bff8a1f7b857399467260b54714d659fd812a41eeee049f0a3e9c8b9aeeb" # rocm7.0
+      "861cd9f7479eec943933c27cb86920247e5b5dd139bc7c1376c81808abb7d7fe" # rocm6.3
+      "acea7d811a2d3bbe718b6e07fc2a9f739e49eecd60b4b6a36fcb3fe8edf85d78" # rocm6.4
+      "7e29c325d5bd33ba896ddb106f5d4fc7d715274dca7fe937f724fffa82017838" # rocm6.5
+      "1e9b3dddf0c7fc07131c6f0f5266129e83ce2331f459fa2be8c63f4ae91b0f5b" # rocm7.0
   )

   set(__AOTRITON_Z "gz")
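aotriton.cmake keeps its per-ROCm metadata in parallel, index-aligned lists (release tag, manylinux tag, ROCm tag, SHA256), so every list must drop or gain rows together, as this patch does when removing rocm6.2 and moving to 0.10b. A Python sketch of the resulting row-per-release table; the dict and helper are illustrative stand-ins for the CMake selection logic, and the SHA256 column is omitted:

    AOTRITON_0_10B_ROWS = {
        # rocm tag -> (package version, manylinux tag)
        "rocm6.3": ("0.10b", "manylinux_2_28"),
        "rocm6.4": ("0.10b", "manylinux_2_28"),
        "rocm6.5": ("0.10b", "manylinux_2_28"),
        "rocm7.0": ("0.10b", "manylinux_2_28"),
    }

    def lookup_release(rocm_tag: str):
        # Illustrative lookup only; the real selection happens in aotriton.cmake.
        return AOTRITON_0_10B_ROWS.get(rocm_tag)

    assert lookup_release("rocm6.4") == ("0.10b", "manylinux_2_28")
    assert lookup_release("rocm6.2") is None  # rocm6.2 rows were dropped by this patch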
From b2fdd04d2f4cbf6a892a1c8aca0ad38ecc310afe Mon Sep 17 00:00:00 2001
From: Xinya Zhang
Date: Wed, 16 Jul 2025 01:11:16 +0000
Subject: [PATCH 25/27] fix test_invalid_fused_inputs_head_dim. AOTriton supports hdim <= 512

---
 test/test_transformers.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/test_transformers.py b/test/test_transformers.py
index d3992a7776aa9..3b656e9f8e3ad 100644
--- a/test/test_transformers.py
+++ b/test/test_transformers.py
@@ -1462,7 +1462,7 @@ def test_invalid_fused_inputs_head_dim(self, device, kernel: SDPBackend):
         make_tensor = partial(torch.rand, device=device, dtype=dtype)
         size = SdpaShape(2, 2, 3, 9) if kernel == SDPBackend.EFFICIENT_ATTENTION else SdpaShape(2, 2, 3, 257)
         if TEST_WITH_ROCM:  # On ROCM, FA and EA share the backend GPU kernels
-            size = SdpaShape(2, 2, 3, 257)
+            size = SdpaShape(2, 2, 3, 513)
         q, k, v = make_tensor(size), make_tensor(size), make_tensor(size)
         self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
             q, k, v, None, 0.0, False))

From 714e850a48b7f84b8c69f1b837e4275b29c3c865 Mon Sep 17 00:00:00 2001
From: Xinya Zhang
Date: Wed, 16 Jul 2025 01:15:01 +0000
Subject: [PATCH 26/27] AOTriton 0.10b needs slightly larger fudge factor for dq

---
 test/test_transformers.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/test/test_transformers.py b/test/test_transformers.py
index 3b656e9f8e3ad..38397e2dab25f 100644
--- a/test/test_transformers.py
+++ b/test/test_transformers.py
@@ -2692,6 +2692,8 @@ def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset,

         dropout_fudge_factor = 1.0 if dropout_p == 0.0 else 2.0
         query_fudge_factor = dropout_fudge_factor
+        if TEST_WITH_ROCM:
+            query_fudge_factor += 1.0
         grad_q_ref_atol, grad_q_ref_rtol = get_tolerances(query_ref.grad, query_ref_lp.grad, query_fudge_factor)

         # TODO: Investigate why grad_k needs larger tolerances
@@ -2814,6 +2816,8 @@ def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset,

         mask_fudge_factor = 1.0 if attn_mask is None else 1.5
         query_fudge_factor = dropout_fudge_factor
+        if TEST_WITH_ROCM:
+            query_fudge_factor += 1.0
         grad_q_ref_atol, grad_q_ref_rtol = get_tolerances(query_ref.grad, query_ref_lp.grad, query_fudge_factor)

         # TODO: Investigate why grad_k needs larger tolerances
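AOTriton 0.10b's dq gradients deviate slightly more from the reference, so on ROCm the query-gradient fudge factor is bumped by 1.0 before tolerances are derived. The patch does not show get_tolerances' internals, so the scaling below is only an illustration of how a multiplicative fudge factor loosens a comparison, with torch's default float32 tolerances used as placeholder inputs:

    def widened_tolerances(base_atol: float, base_rtol: float, fudge: float):
        # Illustrative: scale both tolerances by the fudge factor.
        return base_atol * fudge, base_rtol * fudge

    dropout_fudge_factor = 2.0                       # test value when dropout_p != 0.0
    query_fudge_factor = dropout_fudge_factor + 1.0  # the ROCm bump added by this patch
    atol, rtol = widened_tolerances(1e-5, 1.3e-6, query_fudge_factor)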
From c3834b3e0166735c256bf998fee0979f8eee8ff4 Mon Sep 17 00:00:00 2001
From: Xinya Zhang
Date: Wed, 16 Jul 2025 01:18:00 +0000
Subject: [PATCH 27/27] Fix the adaptor code

---
 .../native/transformers/hip/flash_attn/aot/mha_all_aot.hip | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip b/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip
index 3dbb7448896d5..22203f22079e7 100644
--- a/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip
+++ b/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip
@@ -64,8 +64,8 @@
 #include
 #include

-#if AOTRITON_VERSION_MINOR != 9
-#error "This adaptor code is only tested with AOTriton 0.9.x"
+#if AOTRITON_VERSION_MINOR < 9
+#error "This adaptor code is only tested with AOTriton 0.9+"
 #endif

 namespace pytorch_flash {