From f77d860bdeec894b8a7886025d72ed21ebe2f562 Mon Sep 17 00:00:00 2001 From: Aaryaman Vasishta Date: Mon, 15 Sep 2025 16:13:03 +0000 Subject: [PATCH 1/4] [ROCm/Windows] Support aotriton for scaled_dot_product_attention on Windows. (#162330) Enables flash attention and/or memory efficient attention on Windows with scaled_dot_product_attention via. aotriton. Already tested to be working on Windows with TheRock. Steps to enable: simply set `USE_FLASH_ATTENTION=1` and `USE_MEM_EFF_ATTENTION=1` as usual. See https://github.com/ROCm/TheRock/blob/main/external-builds/pytorch/build_prod_wheels.py#L578-L604 Pull Request resolved: https://github.com/pytorch/pytorch/pull/162330 Approved by: https://github.com/jeffdaily Co-authored-by: Scott Todd --- CMakeLists.txt | 4 +- .../native/transformers/cuda/attention.cu | 66 ++++++++++ .../transformers/hip/flash_attn/flash_api.h | 39 +----- cmake/External/aotriton.cmake | 113 +++++++++++++++++- tools/linter/dictionary.txt | 1 + 5 files changed, 179 insertions(+), 44 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index ce7890f002d3..91181735750d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -873,7 +873,7 @@ cmake_dependent_option( "Whether to build the flash_attention kernel for scaled dot product attention.\ Will be disabled if not supported by the platform" ON - "USE_CUDA OR USE_ROCM;NOT MSVC" + "(USE_CUDA AND NOT MSVC) OR USE_ROCM" OFF) cmake_dependent_option( @@ -908,7 +908,7 @@ cmake_dependent_option( # USE_FLASH_ATTENTION -> USE_ROCM -> Dependencies.cmake -> aotriton.cmake # if(USE_ROCM) - if(UNIX AND (USE_FLASH_ATTENTION OR USE_MEM_EFF_ATTENTION)) + if(USE_FLASH_ATTENTION OR USE_MEM_EFF_ATTENTION) include(cmake/External/aotriton.cmake) endif() endif() diff --git a/aten/src/ATen/native/transformers/cuda/attention.cu b/aten/src/ATen/native/transformers/cuda/attention.cu index b8b43e0086c1..c2193f2378dd 100644 --- a/aten/src/ATen/native/transformers/cuda/attention.cu +++ b/aten/src/ATen/native/transformers/cuda/attention.cu @@ -95,6 +95,72 @@ #endif #endif +#if defined(USE_ROCM) && (defined(USE_FLASH_ATTENTION) || defined(USE_MEM_EFF_ATTENTION)) +namespace pytorch_flash +{ +std::tuple< + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor> +mha_fwd( + const at::Tensor& q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor& k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor& v, // batch_size x seqlen_k x num_heads_k x head_size + std::optional& + out_, // batch_size x seqlen_q x num_heads x head_size + std::optional& + alibi_slopes_, // num_heads or batch_size x num_heads + const float p_dropout, + const float softmax_scale, + bool is_causal, + std::optional window_size_left, + std::optional window_size_right, + const float softcap, + const bool return_softmax, + std::optional gen_) { +#if defined(USE_ROCM_CK_SDPA) + if (at::globalContext().getROCmFAPreferredBackend() == + at::ROCmFABackend::Ck) { + const int non_null_window_left = window_size_left.value_or(-1); + const int non_null_window_right = window_size_right.value_or(-1); + std::optional dummy_attn_bias = std::nullopt; + return mha_fwd_ck( + q, + k, + v, + out_, + p_dropout, + softmax_scale, + is_causal, + non_null_window_left, + non_null_window_right, + return_softmax, + gen_, + dummy_attn_bias); // Not used in flash attention + } +#endif + return mha_fwd_aot( + q, + k, + v, + out_, + alibi_slopes_, + p_dropout, + softmax_scale, + is_causal, + window_size_left, + window_size_right, + 
return_softmax, + gen_); +} +} +#endif + namespace at { namespace cuda::philox { diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h b/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h index f6f2240d4f09..71a195906597 100644 --- a/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h +++ b/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h @@ -270,7 +270,7 @@ std::tuple mha_varle #endif TORCH_API -inline std::tuple< +std::tuple< at::Tensor, at::Tensor, at::Tensor, @@ -294,42 +294,7 @@ mha_fwd( std::optional window_size_right, const float softcap, const bool return_softmax, - std::optional gen_) { -#if defined(USE_ROCM_CK_SDPA) - if (at::globalContext().getROCmFAPreferredBackend() == - at::ROCmFABackend::Ck) { - const int non_null_window_left = window_size_left.value_or(-1); - const int non_null_window_right = window_size_right.value_or(-1); - std::optional dummy_attn_bias = std::nullopt; - return mha_fwd_ck( - q, - k, - v, - out_, - p_dropout, - softmax_scale, - is_causal, - non_null_window_left, - non_null_window_right, - return_softmax, - gen_, - dummy_attn_bias); // Not used in flash attention - } -#endif - return mha_fwd_aot( - q, - k, - v, - out_, - alibi_slopes_, - p_dropout, - softmax_scale, - is_causal, - window_size_left, - window_size_right, - return_softmax, - gen_); -} + std::optional gen_); inline std::tuple< at::Tensor, diff --git a/cmake/External/aotriton.cmake b/cmake/External/aotriton.cmake index 5d9158774654..4f7a79a78bfc 100644 --- a/cmake/External/aotriton.cmake +++ b/cmake/External/aotriton.cmake @@ -45,13 +45,88 @@ if(NOT __AOTRITON_INCLUDED) ) set(__AOTRITON_BASE_URL "https://github.com/ROCm/aotriton/releases/download/") # @lint-ignore set(__AOTRITON_Z "gz") + # Set the default __AOTRITON_LIB path + set(__AOTRITON_LIB "${__AOTRITON_INSTALL_DIR}/lib/libaotriton_v2.so") + if(WIN32) + set(__AOTRITON_LIB "${__AOTRITON_INSTALL_DIR}/lib/aotriton_v2.lib") + endif() + + function(aotriton_build_windows_dependencies dlfcn-win32_external xz_external dlfcn-win32_DIR liblzma_DIR) + # Windows-specific dependencies - build these first + if(NOT noimage) + message(FATAL_ERROR "noimage must be ON for Windows builds") + endif() + # Build dlfcn-win32 + set(__DLFCN_WIN32_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/dlfcn-win32") + set(__DLFCN_WIN32_INSTALL_DIR "${CMAKE_CURRENT_BINARY_DIR}/dlfcn-win32-install") + + ExternalProject_Add(${dlfcn-win32_external} + GIT_REPOSITORY https://github.com/dlfcn-win32/dlfcn-win32.git + GIT_TAG v1.4.2 + PREFIX ${__DLFCN_WIN32_PREFIX} + INSTALL_DIR ${__DLFCN_WIN32_INSTALL_DIR} + CMAKE_ARGS + -DCMAKE_INSTALL_PREFIX=${__DLFCN_WIN32_INSTALL_DIR} + -DCMAKE_BUILD_TYPE=Release + -DCMAKE_C_COMPILER=cl + -DCMAKE_CXX_COMPILER=cl + -DBUILD_SHARED_LIBS=ON + -DBUILD_TESTS=OFF + BUILD_BYPRODUCTS + "${__DLFCN_WIN32_INSTALL_DIR}/lib/dl.lib" + "${__DLFCN_WIN32_INSTALL_DIR}/bin/dl.dll" + ) + ExternalProject_Add_Step(${dlfcn-win32_external} copy_to_aotriton + COMMAND ${CMAKE_COMMAND} -E copy_if_different + "${__DLFCN_WIN32_INSTALL_DIR}/bin/dl.dll" + "${__AOTRITON_INSTALL_DIR}/lib/" + DEPENDEES install + ) + set(${dlfcn-win32_DIR} "${__DLFCN_WIN32_INSTALL_DIR}/share/dlfcn-win32" CACHE PATH "Path to dlfcn-win32 CMake config" FORCE) + + # Build xz/liblzma + set(__XZ_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/xz") + set(__XZ_INSTALL_DIR "${CMAKE_CURRENT_BINARY_DIR}/xz-install") + + ExternalProject_Add(${xz_external} + GIT_REPOSITORY https://github.com/tukaani-project/xz.git + GIT_TAG v5.8.1 + PREFIX ${__XZ_PREFIX} + INSTALL_DIR 
${__XZ_INSTALL_DIR} + CMAKE_ARGS + -DCMAKE_INSTALL_PREFIX=${__XZ_INSTALL_DIR} + -DCMAKE_BUILD_TYPE=Release + -DBUILD_SHARED_LIBS=ON + -DENABLE_NLS=OFF + -DXZ_TOOL_LZMAINFO=OFF + -DXZ_TOOL_XZ=OFF + -DXZ_TOOL_XZDEC=OFF + -DXZ_TOOL_LZMADEC=OFF + BUILD_BYPRODUCTS + "${__XZ_INSTALL_DIR}/lib/lzma.lib" + "${__XZ_INSTALL_DIR}/bin/liblzma.dll" + ) + ExternalProject_Add_Step(${xz_external} copy_to_aotriton + COMMAND ${CMAKE_COMMAND} -E copy_if_different + "${__XZ_INSTALL_DIR}/bin/liblzma.dll" + "${__AOTRITON_INSTALL_DIR}/lib/" + DEPENDEES install + ) + set(${liblzma_DIR} "${__XZ_INSTALL_DIR}/lib/cmake/liblzma" CACHE PATH "Path to xz/liblzma CMake config" FORCE) + endfunction() + function(aotriton_build_from_source noimage project) if(noimage) SET(RECURSIVE "OFF") else() SET(RECURSIVE "ON") endif() + if(WIN32) + message(STATUS "Building AOTriton Windows dependencies") + aotriton_build_windows_dependencies(dlfcn-win32_external xz_external dlfcn-win32_DIR liblzma_DIR) + endif() message(STATUS "PYTORCH_ROCM_ARCH ${PYTORCH_ROCM_ARCH}") + ExternalProject_Add(${project} GIT_REPOSITORY https://github.com/ROCm/aotriton.git GIT_SUBMODULES_RECURSE ${RECURSIVE} @@ -65,12 +140,19 @@ if(NOT __AOTRITON_INCLUDED) -DAOTRITON_GPU_BUILD_TIMEOUT=0 -DAOTRITON_NO_PYTHON=ON -DAOTRITON_NOIMAGE_MODE=${noimage} - BUILD_BYPRODUCTS "${__AOTRITON_INSTALL_DIR}/lib/libaotriton_v2.so" + -DHIP_PLATFORM=amd + $<$:-Ddlfcn-win32_DIR=${dlfcn-win32_DIR}> + $<$:-Dliblzma_DIR=${liblzma_DIR}> + BUILD_BYPRODUCTS + "${__AOTRITON_LIB}" USES_TERMINAL_DOWNLOAD TRUE USES_TERMINAL_CONFIGURE TRUE USES_TERMINAL_BUILD TRUE USES_TERMINAL_INSTALL TRUE ) + if(WIN32) + add_dependencies(${project} dlfcn-win32_external xz_external) + endif() endfunction() set(__AOTRITON_ARCH ${CMAKE_HOST_SYSTEM_PROCESSOR}) @@ -95,7 +177,7 @@ if(NOT __AOTRITON_INCLUDED) INSTALL_COMMAND ${CMAKE_COMMAND} -E copy_directory "${CMAKE_CURRENT_BINARY_DIR}/aotriton_runtime" "${__AOTRITON_INSTALL_DIR}" - BUILD_BYPRODUCTS "${__AOTRITON_INSTALL_DIR}/lib/libaotriton_v2.so" + BUILD_BYPRODUCTS "${__AOTRITON_LIB}" ) message(STATUS "Using AOTriton Runtime from pre-compiled binary ${__AOTRITON_URL}.\ Set env variables AOTRITON_INSTALL_FROM_SOURCE=1 to build from source.") @@ -111,14 +193,35 @@ if(NOT __AOTRITON_INCLUDED) string(CONCAT __AOTRITON_URL "${__AOTRITON_BASE_URL}" "${__AOTRITON_VER}/${__AOTRITON_FILE}") + + # Set up directories + set(__AOTRITON_DOWNLOAD_DIR ${CMAKE_CURRENT_BINARY_DIR}/aotriton_download-${image}) + set(__AOTRITON_EXTRACT_DIR ${CMAKE_CURRENT_BINARY_DIR}/aotriton_image-${image}) + set(__AOTRITON_INSTALL_SOURCE_DIR ${__AOTRITON_EXTRACT_DIR}) + set(__DOWNLOAD_NO_EXTRACT "") + set(__BUILD_COMMANDS "") + + # On Windows, we need custom tar extraction with UTF-8 support + if(WIN32) + set(__DOWNLOAD_NO_EXTRACT "DOWNLOAD_NO_EXTRACT;TRUE") + set(__BUILD_COMMANDS + COMMAND ${CMAKE_COMMAND} -E make_directory "${__AOTRITON_EXTRACT_DIR}" + COMMAND tar --options hdrcharset=UTF-8 -xf "${__AOTRITON_DOWNLOAD_DIR}/${__AOTRITON_FILE}" -C "${__AOTRITON_EXTRACT_DIR}" + ) + set(__AOTRITON_INSTALL_SOURCE_DIR ${__AOTRITON_EXTRACT_DIR}/aotriton) + endif() + ExternalProject_Add(${project} URL "${__AOTRITON_URL}" URL_HASH SHA256=${__AOTRITON_SHA256} - SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/aotriton_image-${image} + DOWNLOAD_DIR ${__AOTRITON_DOWNLOAD_DIR} + ${__DOWNLOAD_NO_EXTRACT} + SOURCE_DIR ${__AOTRITON_EXTRACT_DIR} CONFIGURE_COMMAND "" BUILD_COMMAND "" + ${__BUILD_COMMANDS} INSTALL_COMMAND ${CMAKE_COMMAND} -E copy_directory - "${CMAKE_CURRENT_BINARY_DIR}/aotriton_image-${image}" + 
"${__AOTRITON_INSTALL_SOURCE_DIR}" "${__AOTRITON_INSTALL_DIR}" BUILD_BYPRODUCTS "${__AOTRITON_INSTALL_DIR}/lib/aotriton.images/${image}/__signature__" @@ -164,7 +267,7 @@ if(NOT __AOTRITON_INCLUDED) endforeach() endforeach() endif() - target_link_libraries(__caffe2_aotriton INTERFACE ${__AOTRITON_INSTALL_DIR}/lib/libaotriton_v2.so) + target_link_libraries(__caffe2_aotriton INTERFACE ${__AOTRITON_LIB}) target_include_directories(__caffe2_aotriton INTERFACE ${__AOTRITON_INSTALL_DIR}/include) set(AOTRITON_FOUND TRUE) endif() # __AOTRITON_INCLUDED diff --git a/tools/linter/dictionary.txt b/tools/linter/dictionary.txt index 706881a8f10f..c4a250db0483 100644 --- a/tools/linter/dictionary.txt +++ b/tools/linter/dictionary.txt @@ -12,6 +12,7 @@ BU contiguities contiguity coo +DEPENDEES deser din dout From 31269836ac779904461ce1fa529d25fb72762103 Mon Sep 17 00:00:00 2001 From: Xinya Zhang Date: Thu, 18 Sep 2025 13:53:48 +0000 Subject: [PATCH 2/4] [ROCm] Remove HIPBLASLT_ALLOW_TF32 from codebase (#162998) A few UT failures are caused by `HIPBLASLT_ALLOW_TF32` Fixes #157094 Fixes #157093 Fixes #157092 Fixes #157091 Fixes #157064 Fixes #157063 Fixes #157062 Fixes #157061 Fixes #157042 Fixes #157041 Fixes #157039 Fixes #157004 Pull Request resolved: https://github.com/pytorch/pytorch/pull/162998 Approved by: https://github.com/jeffdaily Co-authored-by: Jeff Daily --- aten/src/ATen/Context.cpp | 21 +------- test/dynamo/test_graph_region_tracker.py | 62 +++++++++--------------- test/dynamo/test_misc.py | 55 +++++++-------------- test/inductor/test_flex_decoding.py | 3 -- test/inductor/test_padding.py | 3 -- test/test_cuda.py | 52 -------------------- test/test_linalg.py | 30 +----------- test/test_transformers.py | 7 ++- torch/cuda/tunable.py | 1 - torch/testing/_internal/common_cuda.py | 10 +--- 10 files changed, 48 insertions(+), 196 deletions(-) diff --git a/aten/src/ATen/Context.cpp b/aten/src/ATen/Context.cpp index 4d48084b0ab8..7a8d02be530e 100644 --- a/aten/src/ATen/Context.cpp +++ b/aten/src/ATen/Context.cpp @@ -180,7 +180,7 @@ void Context::setUserEnabledNNPACK(bool e) { } bool Context::allowTF32CuDNN(const std::string& op) const { - if (op.size() == 0){ + if (op.empty()){ bool allow_tf32_rnn = float32Precision("cuda", "rnn") == "tf32"; bool allow_tf32_conv = float32Precision("cuda", "conv") == "tf32"; TORCH_CHECK( @@ -281,9 +281,6 @@ bool Context::userEnabledOverrideableSDP() const { static constexpr const auto cublas_config_var_name = "CUBLAS_WORKSPACE_CONFIG"; static constexpr const std::array cublas_deterministic_configs = {":4096:8", ":16:8"}; -#ifdef USE_ROCM -static constexpr const auto hipblaslt_allow_tf32 = "HIPBLASLT_ALLOW_TF32"; -#endif bool Context::checkCuBLASConfigDeterministic() { // If using CUDA 10.2 or greater, need to make sure CuBLAS workspace config @@ -343,12 +340,6 @@ void Context::setImmediateMiopen(bool b) { } bool Context::allowTF32CuBLAS() const { -#ifdef USE_ROCM - const auto allow_tf32 = c10::utils::check_env(hipblaslt_allow_tf32); - if (allow_tf32 != true) { - return false; - } -#endif bool legacy_allow_tf32 = float32_matmul_precision != at::Float32MatmulPrecision::HIGHEST; bool allow_tf32_new = float32Precision("cuda", "matmul") == "tf32"; TORCH_CHECK( @@ -362,14 +353,6 @@ bool Context::allowTF32CuBLAS() const { } void Context::setAllowTF32CuBLAS(bool b) { -#ifdef USE_ROCM - const auto allow_tf32 = c10::utils::check_env(hipblaslt_allow_tf32); - if (allow_tf32 != true) { - C10_LOG_FIRST_N(INFO, 10) << "torch.backends.cuda.matmul.allow_tf32 is not supported on ROCm 
by default. " - << "Please set environment variable HIPBLASLT_ALLOW_TF32=1 to enable it."; - return; - } -#endif float32_matmul_precision = b ? at::Float32MatmulPrecision::HIGH : at::Float32MatmulPrecision::HIGHEST; setFloat32Precision("cuda", "matmul", b ? "tf32" : "ieee"); } @@ -443,7 +426,7 @@ void Context::setFloat32Precision(const std::string& backend, const std::string& std::string msg; auto iterp = _fp32_precisions.find(backend); TORCH_CHECK(iterp != _fp32_precisions.end()); - for (auto p : iterp->second) { + for (const auto& p : iterp->second) { msg += p; msg += " "; } diff --git a/test/dynamo/test_graph_region_tracker.py b/test/dynamo/test_graph_region_tracker.py index e930ff787a9a..ce456596fd55 100644 --- a/test/dynamo/test_graph_region_tracker.py +++ b/test/dynamo/test_graph_region_tracker.py @@ -1,6 +1,5 @@ # Owner(s): ["module: dynamo"] import contextlib -import os import torch import torch.fx @@ -196,21 +195,6 @@ def fn(x, y, z): ) def test_mismatched_global_state(self): - @contextlib.contextmanager - def _hip_allow_tf32(): - # for HIP/AMDGPU, tf32 is behind a flag because the TF32 support is new - # and only for MI300+ - hip_allow_tf32 = os.environ.get("HIPBLASLT_ALLOW_TF32", None) - os.environ["HIPBLASLT_ALLOW_TF32"] = "1" - - try: - yield - finally: - if hip_allow_tf32 is not None: - os.environ["HIPBLASLT_ALLOW_TF32"] = hip_allow_tf32 - else: - del os.environ["HIPBLASLT_ALLOW_TF32"] - def inner_fn(x, y): x1 = x * 1 y1 = y + 1 @@ -251,31 +235,29 @@ def set_default_dtype_bfloat16(): def reset_default_dtype(): torch.set_default_dtype(old_dtype) - tf32_ctx = _hip_allow_tf32 if torch.version.hip else contextlib.nullcontext - with tf32_ctx(): - for ctx in [ - lambda: torch.set_grad_enabled(False), - torch.autograd.grad_mode.inference_mode, - lambda: torch.autograd.graph.disable_saved_tensors_hooks( - "This is not supported" - ), - # lambda: torch.set_num_threads(2), : Unsupported - (set_default_dtype_bfloat16, reset_default_dtype), - ( - lambda: torch.use_deterministic_algorithms(True), - lambda: torch.use_deterministic_algorithms(False), - ), - # (lambda: torch.use_deterministic_algorithms(True, warn_only=True), - # lambda: torch.use_deterministic_algorithms(False)), : Unsupported - create_toggle_fns("allow_bf16_reduced_precision_reduction"), - create_toggle_fns("allow_fp16_reduced_precision_reduction"), - create_toggle_fns("allow_tf32"), - ]: - self.assertExpectedInline( - self.get_result(fn, torch.rand(10, 10), torch.ones(10, 20), ctx), - """[[['x1_2', 'y1_2', 'sum_3', 'o0'], ['x1_3', 'y1_3', 'sum_4', 'o2']], \ + for ctx in [ + lambda: torch.set_grad_enabled(False), + torch.autograd.grad_mode.inference_mode, + lambda: torch.autograd.graph.disable_saved_tensors_hooks( + "This is not supported" + ), + # lambda: torch.set_num_threads(2), : Unsupported + (set_default_dtype_bfloat16, reset_default_dtype), + ( + lambda: torch.use_deterministic_algorithms(True), + lambda: torch.use_deterministic_algorithms(False), + ), + # (lambda: torch.use_deterministic_algorithms(True, warn_only=True), + # lambda: torch.use_deterministic_algorithms(False)), : Unsupported + create_toggle_fns("allow_bf16_reduced_precision_reduction"), + create_toggle_fns("allow_fp16_reduced_precision_reduction"), + create_toggle_fns("allow_tf32"), + ]: + self.assertExpectedInline( + self.get_result(fn, torch.rand(10, 10), torch.ones(10, 20), ctx), + """[[['x1_2', 'y1_2', 'sum_3', 'o0'], ['x1_3', 'y1_3', 'sum_4', 'o2']], \ [['x1', 'y1', 'sum_1', 'o4'], ['x1_1', 'y1_1', 'sum_2', 'o5']]]""", - ) + ) def 
test_mutation_tracking_simple(self): def fn(x, y, z): diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 1a9d8e8155e4..0a3891e2dc14 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -8421,43 +8421,24 @@ def write_state(state): def fn(x): return x + 1 - import contextlib - - @contextlib.contextmanager - def _hip_allow_tf32(): - # for HIP/AMDGPU, tf32 is behind a flag because the TF32 support is new - # and only for MI300+ - hip_allow_tf32 = os.environ.get("HIPBLASLT_ALLOW_TF32", None) - os.environ["HIPBLASLT_ALLOW_TF32"] = "1" - - try: - yield - finally: - if hip_allow_tf32 is not None: - os.environ["HIPBLASLT_ALLOW_TF32"] = hip_allow_tf32 - else: - del os.environ["HIPBLASLT_ALLOW_TF32"] - - tf32_ctx = _hip_allow_tf32 if torch.version.hip else contextlib.nullcontext - with tf32_ctx(): - initial_state = read_state() - y = torch.randn(10) - try: - for round in range(3): - for i in range(len(initial_state)): - new_state = [False] * len(initial_state) - new_state[i] = True - write_state(new_state) - assert read_state() == new_state - last_state.clear() - fn(y) - assert last_state == new_state - if round == 0: - assert cnt == i + 1 - else: - assert cnt == len(initial_state) - finally: - write_state(initial_state) + initial_state = read_state() + y = torch.randn(10) + try: + for round in range(3): + for i in range(len(initial_state)): + new_state = [False] * len(initial_state) + new_state[i] = True + write_state(new_state) + assert read_state() == new_state + last_state.clear() + fn(y) + assert last_state == new_state + if round == 0: + assert cnt == i + 1 + else: + assert cnt == len(initial_state) + finally: + write_state(initial_state) def test_grad_state_mutated(self): prior = torch.is_grad_enabled() diff --git a/test/inductor/test_flex_decoding.py b/test/inductor/test_flex_decoding.py index 120d8d36b439..849aefff8a96 100644 --- a/test/inductor/test_flex_decoding.py +++ b/test/inductor/test_flex_decoding.py @@ -43,9 +43,6 @@ Tolerances = namedtuple("Tolerances", ["atol", "rtol"]) -# In MI300, HIPBLASLT_ALLOW_TF32=1 is used to enable tf32 for matmul. -# In the current test, HIPBLASLT_ALLOW_TF32 is not set, according to the -# logic of allowTF32CuBLAS(), set float32_matmul_precision to highest. if torch.version.hip: torch.set_float32_matmul_precision("highest") else: diff --git a/test/inductor/test_padding.py b/test/inductor/test_padding.py index 9ef3a18e2423..c67bde87a369 100644 --- a/test/inductor/test_padding.py +++ b/test/inductor/test_padding.py @@ -109,9 +109,6 @@ def setUpClass(cls): if HAS_GPU: cls.prior_float32_matmul_precision = torch.get_float32_matmul_precision() cls.prior_default_device = torch.get_default_device() - # In MI300, HIPBLASLT_ALLOW_TF32=1 is used to enable tf32 for matmul. - # In the current test, HIPBLASLT_ALLOW_TF32 is not set, according to the - # logic of allowTF32CuBLAS(), set float32_matmul_precision to highest. 
if torch.version.hip: torch.set_float32_matmul_precision("highest") else: diff --git a/test/test_cuda.py b/test/test_cuda.py index 7985a2cd9fe8..d293601fad13 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -762,53 +762,7 @@ def check_workspace_size(inp): torch._C._cuda_clearCublasWorkspaces() - @contextlib.contextmanager - def _hip_allow_tf32(self): - # for HIP/AMDGPU, tf32 is behind a flag because the TF32 support is new - # and only for MI300+ - hip_allow_tf32 = os.environ.get("HIPBLASLT_ALLOW_TF32", None) - os.environ["HIPBLASLT_ALLOW_TF32"] = "1" - - try: - yield - finally: - if hip_allow_tf32 is not None: - os.environ["HIPBLASLT_ALLOW_TF32"] = hip_allow_tf32 - else: - del os.environ["HIPBLASLT_ALLOW_TF32"] - - @unittest.skipIf(not TEST_WITH_ROCM, "not relevant for CUDA testing") - def test_hipblaslt_allow_tf32(self): - tf32_ctx = self._hip_allow_tf32 - with tf32_ctx(): - os.environ["HIPBLASLT_ALLOW_TF32"] = "0" - # Save original value of allow_tf32 - orig = torch.backends.cuda.matmul.allow_tf32 - # If allow_tf32 variable is declared as static in aten/src/ATen/Context.cpp - # then matmul.allow_tf32 will return False after this point even if - # HIP_BLASLT_ALLOW_TF32 is set to 1 and matmul.allow_tf32 is changed. - os.environ["HIPBLASLT_ALLOW_TF32"] = "1" - # Toggle torch.backends.cuda.matmul.allow_tf32 couple of times. - torch.backends.cuda.matmul.allow_tf32 = not orig - test1 = torch.backends.cuda.matmul.allow_tf32 - torch.backends.cuda.matmul.allow_tf32 = orig - test2 = torch.backends.cuda.matmul.allow_tf32 - self.assertNotEqual(test1, test2) - # Restore original value of allow_tf32 - torch.backends.cuda.matmul.allow_tf32 = orig - def test_cublas_allow_tf32_get_set(self): - """ - We only turn on TF32 for MI300 with a special env var. This is because TF32 - is only available in MI300+ and is in experimental mode (hipblaslt support - is current WIP) - """ - tf32_ctx = self._hip_allow_tf32 if torch.version.hip else contextlib.nullcontext - - with tf32_ctx(): - self._test_cublas_allow_tf32_get_set_inner() - - def _test_cublas_allow_tf32_get_set_inner(self): skip_tf32_cublas = "TORCH_ALLOW_TF32_CUBLAS_OVERRIDE" in os.environ and int( os.environ["TORCH_ALLOW_TF32_CUBLAS_OVERRIDE"] ) @@ -823,12 +777,6 @@ def _test_cublas_allow_tf32_get_set_inner(self): torch.backends.cuda.matmul.allow_tf32 = orig def test_float32_matmul_precision_get_set(self): - tf32_ctx = self._hip_allow_tf32 if torch.version.hip else contextlib.nullcontext - - with tf32_ctx(): - self._test_float32_matmul_precision_get_set_inner() - - def _test_float32_matmul_precision_get_set_inner(self): orig = torch.get_float32_matmul_precision() skip_tf32_cublas = "TORCH_ALLOW_TF32_CUBLAS_OVERRIDE" in os.environ and int( os.environ["TORCH_ALLOW_TF32_CUBLAS_OVERRIDE"] diff --git a/test/test_linalg.py b/test/test_linalg.py index 0f6c8f207421..31d4e0d1d92d 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -109,22 +109,6 @@ def get_tunableop_untuned_filename(): return untuned_filename class TestLinalg(TestCase): - @contextlib.contextmanager - def _hip_allow_tf32(self): - # for HIP/AMDGPU, tf32 is behind a flag because the TF32 support is new - # and only for MI300+. Environment variable will be removed in the future. 
- import os - hip_allow_tf32 = os.environ.get("HIPBLASLT_ALLOW_TF32", None) - os.environ["HIPBLASLT_ALLOW_TF32"] = "1" - - try: - yield - finally: - if hip_allow_tf32 is not None: - os.environ["HIPBLASLT_ALLOW_TF32"] = hip_allow_tf32 - else: - del os.environ["HIPBLASLT_ALLOW_TF32"] - def setUp(self): super().setUp() torch.backends.cuda.matmul.allow_tf32 = False @@ -5542,13 +5526,8 @@ def test_scaled_gemm_tunableop(self, device, dtype): @runOnRocmArch(MI300_ARCH) @dtypes(torch.float) def test_tf32_tunableop(self, device, dtype): - # Test TunableOp with TF32. Supported by hipblasLT on MI300+. - # for HIP/AMDGPU, tf32 is behind a flag because the TF32 support is new - # and only for MI300+. Eventually this flag will go away. - tf32_ctx = self._hip_allow_tf32 if torch.version.hip else contextlib.nullcontext - try: - with self._tunableop_ctx(), tf32_ctx(): + with self._tunableop_ctx(): torch.backends.cuda.matmul.allow_tf32 = True torch.cuda.tunable.set_rotating_buffer_size(0) @@ -5611,13 +5590,8 @@ def test_tf32_offline_tunableop(self, device, dtype): # This test is the offline version of test_tf32_tunableop import os - # Test TunableOp with TF32. Supported by hipblasLT on MI300+. - # for HIP/AMDGPU, tf32 is behind a flag because the TF32 support is new - # and only for MI300+. Eventually this flag will go away. - tf32_ctx = self._hip_allow_tf32 if torch.version.hip else contextlib.nullcontext - try: - with self._tunableop_ctx(), tf32_ctx(): + with self._tunableop_ctx(): torch.backends.cuda.matmul.allow_tf32 = True ordinal = torch.cuda.current_device() torch.cuda.tunable.set_rotating_buffer_size(0) diff --git a/test/test_transformers.py b/test/test_transformers.py index 5b240e1f046c..f037ad8aa58c 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -51,7 +51,6 @@ PLATFORM_SUPPORTS_CUDNN_ATTENTION, tf32_on_and_off, tf32_enabled, - ROCM_VERSION, ) if TEST_FAIRSEQ: @@ -340,7 +339,7 @@ def test_train_with_pad_and_catch_error(self, device): l1_bool = nn.L1Loss()(test_train_bool[:, 0:2, :], test_eval_bool[:, 0:2, :]).item() self.assertTrue(l1_bool < 1e-4, "Eval/Train difference in pad_mask BOOL") - @tf32_on_and_off(0.001, only_if=(not TEST_WITH_ROCM or ROCM_VERSION < (7, 0))) + @tf32_on_and_off(0.001) @parametrize("attn_mask_dim", [2, 3, None]) @parametrize("key_padding_mask_dim", [2, None]) @parametrize("mask_dtype", [torch.bool, torch.float32]) @@ -524,7 +523,7 @@ def test_transformerencoder_fastpath(self, device, use_torchscript, enable_neste slowpath_output = slowpath_output.masked_fill(src_key_padding_mask.unsqueeze(-1), 0) self.assertEqual(fastpath_output_expanded, slowpath_output) - @tf32_on_and_off(0.001, only_if=(not TEST_WITH_ROCM or ROCM_VERSION < (7, 0))) + @tf32_on_and_off(0.001) @parametrize("with_no_grad", [True, False]) @parametrize("training", [True, False]) @parametrize("enable_nested_tensor", [False]) @@ -1110,7 +1109,7 @@ def forward( return_all_hiddens=False, )[0] - @tf32_on_and_off(0.003, only_if=(not TEST_WITH_ROCM or ROCM_VERSION < (7, 0))) + @tf32_on_and_off(0.003) @parametrize("input_dim,attn_mask_dim,is_causal", [(3, None, False), (3, 2, False), (3, 2, True), (3, 3, False), (3, 3, True), (4, None, False), (4, 2, False), (4, 2, True), (4, 4, False), (4, 4, True)], diff --git a/torch/cuda/tunable.py b/torch/cuda/tunable.py index c3982c33315e..d1ac7fad7480 100644 --- a/torch/cuda/tunable.py +++ b/torch/cuda/tunable.py @@ -591,7 +591,6 @@ def _process_single_offline_gemm(untuned_gemm_line: str, gpu_id: int) -> None: transA = layout[1] == "T" dtype = 
dtype_dict.get(data_type) if data_type == "tf32": - # User must still set HIPBLASLT_ALLOW_TF32=1 torch.backends.cuda.matmul.allow_tf32 = True else: torch.backends.cuda.matmul.allow_tf32 = False diff --git a/torch/testing/_internal/common_cuda.py b/torch/testing/_internal/common_cuda.py index be284429114f..846d2b407684 100644 --- a/torch/testing/_internal/common_cuda.py +++ b/torch/testing/_internal/common_cuda.py @@ -181,9 +181,6 @@ def tf32_off(): @contextlib.contextmanager def tf32_on(self, tf32_precision=1e-5): - if torch.version.hip: - hip_allow_tf32 = os.environ.get("HIPBLASLT_ALLOW_TF32", None) - os.environ["HIPBLASLT_ALLOW_TF32"] = "1" old_allow_tf32_matmul = torch.backends.cuda.matmul.allow_tf32 old_precision = self.precision try: @@ -192,11 +189,6 @@ def tf32_on(self, tf32_precision=1e-5): with torch.backends.cudnn.flags(enabled=None, benchmark=None, deterministic=None, allow_tf32=True): yield finally: - if torch.version.hip: - if hip_allow_tf32 is not None: - os.environ["HIPBLASLT_ALLOW_TF32"] = hip_allow_tf32 - else: - del os.environ["HIPBLASLT_ALLOW_TF32"] torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32_matmul self.precision = old_precision @@ -246,7 +238,7 @@ def tf32_enabled(): # if device is specified, it will check if device is cuda # if dtype is specified, it will check if dtype is float32 or complex64 # tf32 and fp32 are different only when all the three checks pass -def tf32_on_and_off(tf32_precision=1e-5, only_if=True): +def tf32_on_and_off(tf32_precision=1e-5, *, only_if=True): def with_tf32_disabled(self, function_call): with tf32_off(): function_call() From b2aa2750707eb894433e025a4359f9f327413889 Mon Sep 17 00:00:00 2001 From: Xinya Zhang Date: Mon, 22 Sep 2025 15:01:18 +0000 Subject: [PATCH 3/4] [ROCm] Fix environment variable AOTRITON_INSTALLED_PREFIX (#163373) Early assignment of `__AOTRITON_LIB` breaks the usage of environment variable `$AOTRITON_INSTALLED_PREFIX` Pull Request resolved: https://github.com/pytorch/pytorch/pull/163373 Approved by: https://github.com/pruthvistony, https://github.com/jeffdaily --- cmake/External/aotriton.cmake | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/cmake/External/aotriton.cmake b/cmake/External/aotriton.cmake index 4f7a79a78bfc..f09f77bedb80 100644 --- a/cmake/External/aotriton.cmake +++ b/cmake/External/aotriton.cmake @@ -46,9 +46,10 @@ if(NOT __AOTRITON_INCLUDED) set(__AOTRITON_BASE_URL "https://github.com/ROCm/aotriton/releases/download/") # @lint-ignore set(__AOTRITON_Z "gz") # Set the default __AOTRITON_LIB path - set(__AOTRITON_LIB "${__AOTRITON_INSTALL_DIR}/lib/libaotriton_v2.so") - if(WIN32) - set(__AOTRITON_LIB "${__AOTRITON_INSTALL_DIR}/lib/aotriton_v2.lib") + if(NOT WIN32) + set(__AOTRITON_LIB "lib/libaotriton_v2.so") + else() + set(__AOTRITON_LIB "lib/aotriton_v2.lib") endif() function(aotriton_build_windows_dependencies dlfcn-win32_external xz_external dlfcn-win32_DIR liblzma_DIR) @@ -143,8 +144,7 @@ if(NOT __AOTRITON_INCLUDED) -DHIP_PLATFORM=amd $<$:-Ddlfcn-win32_DIR=${dlfcn-win32_DIR}> $<$:-Dliblzma_DIR=${liblzma_DIR}> - BUILD_BYPRODUCTS - "${__AOTRITON_LIB}" + BUILD_BYPRODUCTS "${__AOTRITON_INSTALL_DIR}/${__AOTRITON_LIB}" USES_TERMINAL_DOWNLOAD TRUE USES_TERMINAL_CONFIGURE TRUE USES_TERMINAL_BUILD TRUE @@ -177,7 +177,7 @@ if(NOT __AOTRITON_INCLUDED) INSTALL_COMMAND ${CMAKE_COMMAND} -E copy_directory "${CMAKE_CURRENT_BINARY_DIR}/aotriton_runtime" "${__AOTRITON_INSTALL_DIR}" - BUILD_BYPRODUCTS "${__AOTRITON_LIB}" + BUILD_BYPRODUCTS 
"${__AOTRITON_INSTALL_DIR}/${__AOTRITON_LIB}" ) message(STATUS "Using AOTriton Runtime from pre-compiled binary ${__AOTRITON_URL}.\ Set env variables AOTRITON_INSTALL_FROM_SOURCE=1 to build from source.") @@ -267,7 +267,7 @@ if(NOT __AOTRITON_INCLUDED) endforeach() endforeach() endif() - target_link_libraries(__caffe2_aotriton INTERFACE ${__AOTRITON_LIB}) + target_link_libraries(__caffe2_aotriton INTERFACE "${__AOTRITON_INSTALL_DIR}/${__AOTRITON_LIB}") target_include_directories(__caffe2_aotriton INTERFACE ${__AOTRITON_INSTALL_DIR}/include) set(AOTRITON_FOUND TRUE) endif() # __AOTRITON_INCLUDED From 7286cf8a19fba6420029944ae0c35eb576ed650f Mon Sep 17 00:00:00 2001 From: Xinya Zhang Date: Thu, 25 Sep 2025 17:14:16 +0000 Subject: [PATCH 4/4] [ROCm] Transformer/SDPA unit test parity (#163745) ## Major Changes * Efficient Attention on ROCM requires last dimensions of input tensors align with 16 bytes. - Unlike FA, ME does not pad input tensors in `scaled_dot_product_attention` and hence this is required. * Fix `atomic_counter` handling in varlen FA API * Unskips a few unit tests. Fixes #157120 Fixes #157121 Fixes #157122 Fixes #157167 Fixes #155217 Fixes #157043 Fixes #157060 Pull Request resolved: https://github.com/pytorch/pytorch/pull/163745 Approved by: https://github.com/jeffdaily --- .../native/transformers/cuda/sdp_utils.cpp | 22 +++++++++++++++++++ .../hip/flash_attn/aot/mha_all_aot.hip | 5 +++-- test/nn/test_multihead_attention.py | 2 -- test/test_flop_counter.py | 3 --- test/test_nn.py | 7 +----- test/test_transformers.py | 20 +++-------------- 6 files changed, 29 insertions(+), 30 deletions(-) diff --git a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp index 660aee3647ce..8eec0de7773f 100644 --- a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp +++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp @@ -176,6 +176,28 @@ bool check_head_dim_size_flash(sdp_params const& params, bool debug) { } return false; } + if constexpr(caller_is_meff) { + bool is_half = (params.query.dtype() == at::kHalf) || + (params.query.dtype() == at::kBFloat16); + const int64_t alignment = is_half ? 8 : 4; + if (!(query_size_last % alignment == 0 && query_size_last > 0 && + value_size_last % alignment == 0 && value_size_last > 0)) { + if (debug) { + TORCH_WARN( + "Mem efficient attention requires last dimension of inputs to be divisible by ", + alignment, + ". 
", + "Got Query.size(-1): ", + query_size_last, + ", Key.size(-1): ", + params.key.sym_size(-1), + ", Value.size(-1): ", + params.value.sym_size(-1), + " instead."); + } + return false; + } + } return true; } diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip b/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip index b5b1ed429289..2467cb809fdb 100644 --- a/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip +++ b/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip @@ -462,10 +462,11 @@ mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, tot using sdp::aotriton_adapter::mk_aotensor; using sdp::aotriton_adapter::mk_aoscalartensor; using sdp::aotriton_adapter::mk_philoxtensor; + using sdp::aotriton_adapter::mk_atomictensor; using sdp::aotriton_adapter::cast_dtype; at::Tensor atomic_counter; if (is_causal) { - atomic_counter = at::zeros({1}, q.options()); + atomic_counter = at::zeros({1}, q.options().dtype(at::kInt)); } aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype())); auto seed = use_philox_state ? mk_philoxtensor(philox_state.seed_.ptr) : mk_aoscalartensor(seed_t); @@ -474,7 +475,7 @@ mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, tot auto nullscalar = mk_philoxtensor(nullptr); auto seed_output = use_philox_state ? mk_philoxtensor(seed_t.data_ptr()) : nullscalar; auto offset_output = use_philox_state ? mk_philoxtensor(offset_t.data_ptr()) : nullscalar; - auto persistent_counter = is_causal ? mk_philoxtensor(atomic_counter.data_ptr()) : nullscalar; + auto persistent_counter = mk_atomictensor(is_causal ? atomic_counter.data_ptr() : nullptr); if (uses_swa || AOTRITON_ALWAYS_V3_API) { #if AOTRITON_V3_API using aotriton::v3::flash::CausalType; diff --git a/test/nn/test_multihead_attention.py b/test/nn/test_multihead_attention.py index c0419664d009..40dca90b1648 100644 --- a/test/nn/test_multihead_attention.py +++ b/test/nn/test_multihead_attention.py @@ -17,7 +17,6 @@ instantiate_parametrized_tests, parametrize as parametrize_test, run_tests, - skipIfRocm, TEST_NUMPY, TEST_WITH_CROSSREF, ) @@ -746,7 +745,6 @@ def test_multihead_attn_nested_tensor_outside_fast_path(self): class TestMultiheadAttentionNNDeviceType(NNTestCase): - @skipIfRocm(msg="To investigate: yields NaN") def test_multihead_self_attn_two_masks_fast_path(self, device): """ Multihead self-attention should give the same result on the fast path (BetterTransformer) as on the slow path diff --git a/test/test_flop_counter.py b/test/test_flop_counter.py index c44d5e5d4145..17e699e04e58 100644 --- a/test/test_flop_counter.py +++ b/test/test_flop_counter.py @@ -15,7 +15,6 @@ ) from torch.testing._internal.common_utils import ( run_tests, - skipIfRocm, TEST_WITH_TORCHDYNAMO, TestCase, ) @@ -463,7 +462,6 @@ def get_flops( self.assertExpectedInline(str(flops_fw_bw_math), """805306368""") self.assertExpectedInline(str(flops_fw_bw_efficient), """939524096""") - @skipIfRocm # Nested tensor @unittest.skipIf(not HAS_CUDA, "CUDA not available") @unittest.skipIf( not PLATFORM_SUPPORTS_FLASH_ATTENTION @@ -683,7 +681,6 @@ def split_tensor(x): ), ) - @skipIfRocm # Nested tensor @unittest.skipIf(not HAS_CUDA, "CUDA not available") @unittest.skipIf( not PLATFORM_SUPPORTS_FLASH_ATTENTION, diff --git a/test/test_nn.py b/test/test_nn.py index c17f7cb668b6..d5c245c5887d 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -39,7 +39,7 @@ parametrize as parametrize_test, subtest, 
instantiate_parametrized_tests, \ skipIfTorchDynamo, gcIfJetson, set_default_dtype from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, \ - PLATFORM_SUPPORTS_FLASH_ATTENTION, _get_torch_rocm_version + _get_torch_rocm_version from torch.testing._internal.common_nn import NNTestCase, NewModuleTest, CriterionTest, \ module_tests, criterion_tests, loss_reference_fns, _create_basic_net, \ ctcloss_reference, get_new_module_tests, single_batch_reference_fn, _test_bfloat16_ops, _test_module_empty_input @@ -3166,7 +3166,6 @@ def perm_fn(x): [2.42240309, 0.0354595, -0.60659063, -0.05378816]]])) torch.testing.assert_close(result, ref_output, rtol=1e-5, atol=0) - @skipIfRocm(msg='Large numerical errors') def test_transformerdecoder(self): def get_a_test_layer(use_cuda, activation, batch_first=False): d_model = 4 @@ -12998,8 +12997,6 @@ def test_skip_init(self, device): @dtypes(torch.float) @dtypesIfCUDA(torch.double, torch.float, torch.half) def test_transformerencoderlayer(self, device, dtype): - if TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION and dtype == torch.half: - self.skipTest("Skip on ROCM due to Flash Attention tolerances") # this is a deterministic test for TransformerEncoderLayer d_model = 4 nhead = 2 @@ -13221,8 +13218,6 @@ def test_transformerencoderlayer_fast_path(self, device, dtype): @dtypes(torch.float) @dtypesIfCUDA(torch.half, torch.float) def test_transformerencoderlayer_gelu(self, device, dtype): - if TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION and dtype == torch.half: - self.skipTest("Skip on ROCM due to Flash Attention tolerances") # this is a deterministic test for TransformerEncoderLayer with gelu activation d_model = 4 nhead = 2 diff --git a/test/test_transformers.py b/test/test_transformers.py index f037ad8aa58c..b2a3959a5042 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -344,9 +344,6 @@ def test_train_with_pad_and_catch_error(self, device): @parametrize("key_padding_mask_dim", [2, None]) @parametrize("mask_dtype", [torch.bool, torch.float32]) def test_multiheadattention_fastpath_attn_mask(self, device, attn_mask_dim, key_padding_mask_dim, mask_dtype): - if TEST_WITH_ROCM: - if attn_mask_dim is not None and mask_dtype == torch.bool: - self.skipTest("boolean mask is not fully supported on ROCm yet.") # MHA converts all with torch.no_grad(): B = 2 @@ -429,8 +426,7 @@ def hook(module, inputs, output): # remove hook handle.remove() - @skipIfRocm - @tf32_on_and_off(0.001) + @tf32_on_and_off(0.0021 if TEST_WITH_ROCM else 0.001) @parametrize("use_torchscript", [False]) @parametrize("enable_nested_tensor", [True, False]) @parametrize("use_autocast", [True, False]) @@ -1420,7 +1416,6 @@ def ones_tensor(*shape): _ = mha_f(qkv_f, qkv_f, qkv_f, attn_mask=mask, need_weights=False, is_causal=True) torch.cuda.synchronize() - @skipIfRocm # Missing EFFICIENT_ATTENTION @unittest.skipIf( not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Platform does not supposrt fused SDPA or pre-SM80 hardware" ) @@ -1713,7 +1708,7 @@ def test_unaligned_tensors(self, device): make_tensor = partial(torch.rand, size, device=device, dtype=dtype) q, k, v = make_tensor(), make_tensor(), make_tensor() with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]): - ctxmgr = self.assertRaises(RuntimeError) if not TEST_WITH_ROCM else contextlib.nullcontext() + ctxmgr = self.assertRaises(RuntimeError) with ctxmgr: torch.nn.functional.scaled_dot_product_attention(q, k, v, None, 0.0, False) @@ -2611,7 +2606,6 @@ def convert_flash_attn_S_to_softmax( 
S_converted = F.pad(S_converted, (0, seqlen_k_og - seqlen_k_rounded)) return S_converted[:, :, :seqlen_q, :seqlen_k] - @skipIfRocm # No cuDNN Attention @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system") def test_cudnn_attention_different_dk_dv(self, device): dtype = torch.bfloat16 @@ -2635,7 +2629,6 @@ def test_cudnn_attention_different_dk_dv(self, device): self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=1e-3, rtol=1e-2) - @skipIfRocm # No cuDNN Attention @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system") def test_cudnn_attention_gqa(self, device): batch = 4 @@ -2659,7 +2652,6 @@ def test_cudnn_attention_gqa(self, device): self.assertEqual(output_math, output_cudnn) - @skipIfRocm # No cuDNN Attention @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system") def test_cudnn_attention_d256_heuristic(self, device): dtype = torch.bfloat16 @@ -2690,7 +2682,6 @@ def test(): with self.assertRaisesRegex(RuntimeError, "No available kernel."): test() - @skipIfRocm(msg="No cuDNN on ROCm") @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system") def test_fused_attention_different_dk_dv(self, device): dtype = torch.bfloat16 @@ -2714,7 +2705,7 @@ def test_fused_attention_different_dk_dv(self, device): self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=1e-3, rtol=1e-2) - @skipIfRocm # No cuDNN Attention + @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system") @unittest.skipIf(True, "broken as of cuDNN 9.10") def test_cudnn_attention_fail_d128(self, device): # Test that cuDNN attention dispatching correctly bails out on d > 128 @@ -2736,7 +2727,6 @@ def test_cudnn_attention_fail_d128(self, device): with self.assertRaisesRegex(RuntimeError, "No available kernel."): torch.nn.functional.scaled_dot_product_attention(q, k, v) - @skipIfRocm(msg="No cuDNN on ROCm") @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cudnn Attention is not supported on this system") def test_cudnn_attention_trivial_output_transpose(self, device): # see also: https://github.com/pytorch/pytorch/issues/134001 @@ -2752,7 +2742,6 @@ def test_cudnn_attention_trivial_output_transpose(self, device): o.backward(o) torch.testing.assert_close(x.grad, x_cpu.grad.cuda(), atol=7e-3, rtol=7e-3) - @skipIfRocm # No cuDNN Attention @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cudnn Attention is not supported on this system") def test_cudnn_attention_nonmodulo64seqlen(self, device): # see also: https://github.com/pytorch/pytorch/issues/137347 @@ -2792,7 +2781,6 @@ def test_cudnn_attention_nonmodulo64seqlen(self, device): torch.testing.assert_close(k.grad, k_cpu.grad.cuda(), atol=3e-3, rtol=2e-3) torch.testing.assert_close(v.grad, v_cpu.grad.cuda(), atol=3e-3, rtol=2e-3) - @skipIfRocm @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cudnn Attention is not supported on this system") def test_cudnn_attention_preserves_query_layout(self, device): @@ -2822,7 +2810,6 @@ def test_attention(backend: SDPBackend, permute_order: list[list[int]]): for permute_order in permute_orders: test_attention(SDPBackend.CUDNN_ATTENTION, list(permute_order) + [3]) - @skipIfRocm @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cudnn Attention is not supported on this system") def test_cudnn_attention_compiles(self): q = 
torch.randn(2, 8, 1024, 128, dtype=torch.half, device='cuda', requires_grad=True) @@ -3241,7 +3228,6 @@ def test_sdp_choice_with_determinism(self, device, warn_only): with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]): assert torch._fused_sdp_choice(query, key, value) == SDPBackend.EFFICIENT_ATTENTION.value - @skipIfRocm @onlyCUDA @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system") @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Platform does not support fused SDPA")
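
A minimal sketch of the 16-byte alignment rule that patch 4 adds to check_head_dim_size_flash for the mem-efficient caller, as seen from the Python side. It assumes a ROCm build with the mem-efficient backend available; the helper name and tensor shapes below are illustrative only, and only the last dimension matters for the check (multiples of 8 elements for float16/bfloat16, 4 for float32).

    # Illustrative sketch: on ROCm the mem-efficient backend now requires the
    # last dimension of q/k/v to be 16-byte aligned (8 elements for fp16/bf16,
    # 4 elements for fp32), mirroring the check added in sdp_utils.cpp.
    import torch
    import torch.nn.functional as F
    from torch.nn.attention import sdpa_kernel, SDPBackend

    def run_mem_efficient(head_dim, dtype=torch.float16):
        # Shapes are arbitrary; only the last dimension is relevant here.
        q, k, v = (torch.randn(2, 8, 128, head_dim, device="cuda", dtype=dtype)
                   for _ in range(3))
        with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
            return F.scaled_dot_product_attention(q, k, v)

    run_mem_efficient(64)      # 64 % 8 == 0 -> aligned, kernel is eligible
    try:
        run_mem_efficient(65)  # 65 % 8 != 0 -> backend rejected
    except RuntimeError as e:
        print(e)               # no eligible kernel with only EFFICIENT_ATTENTION enabled

With only the mem-efficient backend enabled, an unaligned last dimension leaves no eligible kernel and scaled_dot_product_attention raises a RuntimeError, which is the behavior the updated test_unaligned_tensors now expects on ROCm as well.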