diff --git a/CMakeLists.txt b/CMakeLists.txt index 3fc51fa38289..944d7cd557e9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -865,6 +865,13 @@ cmake_dependent_option( "USE_CUDA OR USE_ROCM;NOT MSVC" OFF) +cmake_dependent_option( + USE_CK_FLASH_ATTENTION + "Whether to build the CK flash_attention kernel. Will be enabled if USE_FLASH_ATTENTION is enabled." + ON + "USE_FLASH_ATTENTION" + OFF) + # We are currenlty not using alibi attention for Flash So we disable this # feature by default We dont currently document this feature because we don't # Suspect users building from source will need this @@ -888,6 +895,13 @@ if(USE_ROCM) endif() endif() +# CK shared lib linkage +if(USE_ROCM) + if(UNIX AND (USE_CK_FLASH_ATTENTION)) + include(cmake/External/ck.cmake) + endif() +endif() + if(DEBUG_CUDA) string(APPEND CMAKE_CUDA_FLAGS_DEBUG " -lineinfo") string(APPEND CMAKE_CUDA_FLAGS_RELWITHDEBINFO " -lineinfo") diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index cff157c784c6..6378342367ff 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -172,22 +172,20 @@ file(GLOB flash_attention_cuda_cpp "native/transformers/cuda/flash_attn/*.cpp") file(GLOB flash_attention_hip_hip "native/transformers/hip/flash_attn/*.hip") # if USE_FLASH_ATTENTION is set, ensure CK instances get generated if(USE_FLASH_ATTENTION) - if(DEFINED ENV{USE_CK_FLASH_ATTENTION}) - set(USE_CK_FLASH_ATTENTION $ENV{USE_CK_FLASH_ATTENTION}) - if(USE_CK_FLASH_ATTENTION STREQUAL "1") - if(DEFINED ENV{PYTORCH_ROCM_ARCH}) - list(LENGTH PYTORCH_ROCM_ARCH NUM_ARCHS) - if(NUM_ARCHS GREATER 1) - message(WARNING "Building CK for multiple archs can increase build time considerably! - Consider setting PYTORCH_ROCM_ARCH env var value as the gfx arch you need to build for") - endif() - endif() - message(STATUS "USE_CK_FLASH_ATTENTION is set; building PyTorch with CK Flash Attention enabled") - message(STATUS "Generating CK kernel instances...") - add_subdirectory(native/transformers/hip/flash_attn/ck) - file(GLOB flash_attention_hip_ck_hip "native/transformers/hip/flash_attn/ck/*.hip") - list(APPEND native_transformers_hip_hip ${flash_attention_hip_ck_hip}) + if(USE_CK_FLASH_ATTENTION) + if(DEFINED ENV{PYTORCH_ROCM_ARCH}) + list(LENGTH PYTORCH_ROCM_ARCH NUM_ARCHS) + if(NUM_ARCHS GREATER 1) + message(WARNING "Building CK for multiple archs can increase build time considerably! + Consider setting PYTORCH_ROCM_ARCH env var value as the gfx arch you need to build for") endif() + endif() + message(STATUS "USE_CK_FLASH_ATTENTION is set; building PyTorch with CK Flash Attention enabled") + message(STATUS "Generating CK kernel instances...") + # disable buidling CK files + # add_subdirectory(native/transformers/hip/flash_attn/ck) + file(GLOB flash_attention_hip_ck_hip "native/transformers/hip/flash_attn/ck/*.hip") + list(APPEND native_transformers_hip_hip ${flash_attention_hip_ck_hip}) endif() file(GLOB flash_attention_hip_aot_hip "native/transformers/hip/flash_attn/aot/*.hip") file(GLOB flash_attention_src_hip_hip "native/transformers/hip/flash_attn/src/*.hip") diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/CMakeLists.txt b/aten/src/ATen/native/transformers/hip/flash_attn/ck/CMakeLists.txt deleted file mode 100644 index a72911cd510e..000000000000 --- a/aten/src/ATen/native/transformers/hip/flash_attn/ck/CMakeLists.txt +++ /dev/null @@ -1,63 +0,0 @@ -# generate a list of kernels, but not actually emit files at config stage -execute_process( - COMMAND python3 ${CMAKE_CURRENT_LIST_DIR}/../../../../../../../../third_party/composable_kernel/example/ck_tile/01_fmha/generate.py - --api fwd --receipt 4 --list_blobs ${CMAKE_CURRENT_LIST_DIR}/fwd_blob_list.txt - RESULT_VARIABLE ret -) - -if(ret AND NOT ret EQUAL 0) - message( FATAL_ERROR "CK Tile FMHA FAILED to generate a list of FWD kernels via Python.") -endif() - -execute_process( - COMMAND python3 ${CMAKE_CURRENT_LIST_DIR}/../../../../../../../../third_party/composable_kernel/example/ck_tile/01_fmha/generate.py - --api bwd --receipt 4 --list_blobs ${CMAKE_CURRENT_LIST_DIR}/bwd_blob_list.txt - RESULT_VARIABLE ret -) - -if(ret AND NOT ret EQUAL 0) - message( FATAL_ERROR "CK Tile FMHA FAILED to generate a list of BWD kernels via Python.") -endif() - -# Generate the files for both fwd and bwd -execute_process(COMMAND python3 ${CMAKE_CURRENT_LIST_DIR}/../../../../../../../../third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --api fwd --receipt 4 --output_dir ${CMAKE_CURRENT_LIST_DIR} -) - -if(ret AND NOT ret EQUAL 0) - message( FATAL_ERROR "CK Tile FMHA FAILED to generate FWD kernels.") -endif() - -execute_process(COMMAND python3 ${CMAKE_CURRENT_LIST_DIR}/../../../../../../../../third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --api bwd --receipt 4 --output_dir ${CMAKE_CURRENT_LIST_DIR} - RESULT_VARIABLE ret -) - -if(ret AND NOT ret EQUAL 0) - message( FATAL_ERROR "CK Tile FMHA FAILED to generate BWD kernels.") -endif() - -# Change make_kernel to make_kernel_pt for fwd -execute_process( - COMMAND bash -c "${CMAKE_CURRENT_LIST_DIR}/add_make_kernel_pt.sh ${CMAKE_CURRENT_LIST_DIR}/fwd_blob_list.txt" - RESULT_VARIABLE ret) - -if(ret AND NOT ret EQUAL 0) - message( FATAL_ERROR "CK Tile FMHA FAILED to change make_kernel to make_kernel_pt for the fwd pass") -endif() - -# Change make_kernel to make_kernel_pt for bwd -execute_process( - COMMAND bash -c "${CMAKE_CURRENT_LIST_DIR}/add_make_kernel_pt.sh ${CMAKE_CURRENT_LIST_DIR}/bwd_blob_list.txt" - RESULT_VARIABLE ret) - -if(ret AND NOT ret EQUAL 0) - message( FATAL_ERROR "CK Tile FMHA FAILED to change make_kernel to make_kernel_pt for the bwd pass") -endif() - -# Change file extensions to .hip -execute_process(COMMAND bash -c "for file in ${CMAKE_CURRENT_LIST_DIR}/*.cpp; do mv -- \"$file\" \"\${file%.cpp}.hip\"; done" - RESULT_VARIABLE ret -) - -if(ret AND NOT ret EQUAL 0) - message( FATAL_ERROR "CK Tile FMHA FAILED to change the generated instances extensions from .cpp to .hpp") -endif() diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 83aa67084667..c9b520ef9499 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -929,6 +929,10 @@ if(USE_ROCM) if(USE_FLASH_ATTENTION) target_link_libraries(torch_hip PRIVATE __caffe2_aotriton) endif() +# link CK library + if(USE_CK_FLASH_ATTENTION) + target_link_libraries(torch_hip PRIVATE __ck_lib) + endif() set(CUDA_LINK_LIBRARIES_KEYWORD) torch_compile_options(torch_hip) # see cmake/public/utils.cmake # TODO: Not totally sure if this is live or not diff --git a/cmake/External/ck.cmake b/cmake/External/ck.cmake new file mode 100644 index 000000000000..ac2165a701d1 --- /dev/null +++ b/cmake/External/ck.cmake @@ -0,0 +1,43 @@ +# +# create INTERFACE target for CK library +# + +# get CK commit hash +execute_process( + COMMAND git -C ${CMAKE_SOURCE_DIR}/third_party submodule status composable_kernel + RESULT_VARIABLE result + OUTPUT_VARIABLE submodule_status + ERROR_VARIABLE submodule_status_error + OUTPUT_STRIP_TRAILING_WHITESPACE + ) +if(result EQUAL 0) + string(REGEX REPLACE "^[ \t]" "" submodule_status ${submodule_status}) + # extract first 8 characters of the commit hash + string(SUBSTRING "${submodule_status}" 0 8 ck_commit_hash) +else() + message(FATAL_ERROR "Failed to get submodule status for composable_kernel.") +endif() + +# get ROCm version from LoadHIP.cmake +include(${CMAKE_SOURCE_DIR}/cmake/public/LoadHIP.cmake) + +# full path for CK library on compute-artifactory.amd.com +set(url "https://compute-artifactory.amd.com/artifactory/rocm-generic-local") +set(ck_lib_full_path "${url}/torch_ck_gen_lib/ck_${ck_commit_hash}/rocm_${ROCM_VERSION_DEV}/libck_kernels.so") + +# set destination +set(destination "${CMAKE_SOURCE_DIR}/torch/lib/libck_kernels.so") + +# download CK library +file(DOWNLOAD ${ck_lib_full_path} ${destination} SHOW_PROGRESS RESULT_VARIABLE download_status) +if(NOT download_status) + message(STATUS "Downloaded CK library successfully.") +else() + message(FATAL_ERROR "Failed to download the CK library from ${SOURCE_URL}.") +endif() + +# create INTERFACE target +add_library(__ck_lib INTERFACE) + +# specify path to CK library +target_link_libraries(__ck_lib INTERFACE ${destination})