diff --git a/.gitignore b/.gitignore index 44de0a19e..874eed018 100644 --- a/.gitignore +++ b/.gitignore @@ -52,7 +52,5 @@ compile_commands.json **/profiler_outputs/ **/times.csv tensor_dumps/ -aiter/ transformer_engine/build_info.txt transformer_engine/common/util/hip_nvml.* -transformer_engine/aiter/ diff --git a/3rdparty/aiter b/3rdparty/aiter index a2ca1b460..74e71eb8e 160000 --- a/3rdparty/aiter +++ b/3rdparty/aiter @@ -1 +1 @@ -Subproject commit a2ca1b460f097a309ee5a128c7454b1c419dc331 +Subproject commit 74e71eb8ee8a663d5e33c0cfd8b4dad7708ae84b diff --git a/setup.py b/setup.py index 41893644c..91817d56e 100644 --- a/setup.py +++ b/setup.py @@ -33,7 +33,6 @@ from setuptools.command.build_ext import build_ext as BuildExtension -from setuptools.command.develop import develop as _develop os.environ["NVTE_PROJECT_BUILDING"] = "1" @@ -48,26 +47,6 @@ if not rocm_build(): archs = cuda_archs() -# A custom develop command only used for ROCm builds -class develop(_develop): - def run(self): - super().run() - if ( - int(os.getenv("NVTE_FUSED_ATTN_CK", "1")) and - int(os.getenv("NVTE_FUSED_ATTN", "1")) - ): - # Ensure that the AITER ASM kernels are properly available at runtime - # by creating a symlink to them. This is only necessary for editable - # mode since our C++ code assumes the AITER ASM kernel paths relative - # to trasnformer_engine.so, which is different in editable installs. 
- project_dir = Path(__file__).parent - asm_src_dir = project_dir / 'transformer_engine' / 'aiter' - # Must be synced with - # TransformerEngine/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp - asm_target_dir = project_dir / 'aiter' - if asm_src_dir.is_dir() and not asm_target_dir.is_dir(): - asm_target_dir.symlink_to(asm_src_dir) - class TimedBdist(bdist_wheel): """Helper class to measure build time""" @@ -89,7 +68,7 @@ def setup_common_extension() -> CMakeExtension: cmake_flags.append(f"-DCK_FUSED_ATTN_FLOAT_TO_BFLOAT16_DEFAULT={os.getenv('NVTE_CK_FUSED_ATTN_FLOAT_TO_BFLOAT16_DEFAULT', 3)}") if os.getenv("NVTE_CK_FUSED_ATTN_PATH"): ck_path = Path(os.getenv("NVTE_CK_FUSED_ATTN_PATH")) - cmake_flags.append(f"-DCK_FUSED_ATTN_PATH={ck_path}") + cmake_flags.append(f"-DAITER_MHA_PATH={ck_path}") if int(os.getenv("NVTE_FUSED_ATTN_AOTRITON", "1"))==0 or int(os.getenv("NVTE_FUSED_ATTN", "1"))==0: cmake_flags.append("-DUSE_FUSED_ATTN_AOTRITON=OFF") if int(os.getenv("NVTE_FUSED_ATTN_CK", "1"))==0 or int(os.getenv("NVTE_FUSED_ATTN", "1"))==0: @@ -192,7 +171,6 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: with open("README.rst", encoding="utf-8") as f: long_description = f.read() - cmdclass = {"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist} # Settings for building top level empty package for dependency management. if bool(int(os.getenv("NVTE_BUILD_METAPACKAGE", "0"))): assert bool( @@ -200,6 +178,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: ), "NVTE_RELEASE_BUILD env must be set for metapackage build." 
te_cuda_vers = "rocm" if rocm_build() else "cu12" ext_modules = [] + cmdclass = {} package_data = {} include_package_data = False setup_requires = [] @@ -211,8 +190,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: else: setup_requires, install_requires, test_requires = setup_requirements() ext_modules = [setup_common_extension()] - if rocm_build(): - cmdclass["develop"] = develop + cmdclass = {"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist} package_data = {"": ["VERSION.txt"]} include_package_data = True extras_require = {"test": test_requires} diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index a9e2e056e..f70c9f8bb 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -351,18 +351,7 @@ else() endif() if(USE_FUSED_ATTN_CK) - if(NOT DEFINED CK_FUSED_ATTN_PATH) - set(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT ${CK_FUSED_ATTN_FLOAT_TO_BFLOAT16_DEFAULT} CACHE STRING "ck float to bf16 conversion rounding") - add_subdirectory(ck_fused_attn ${CMAKE_CURRENT_BINARY_DIR}/ck_fused_attn) - else() - # Use CK built during initial TE building/installation - # When only need rebuild TE library itself - unset(CK_FUSED_ATTN_LIB CACHE) - find_library(CK_FUSED_ATTN_LIB NAMES ck_fused_attn PATHS ${CK_FUSED_ATTN_PATH}/lib REQUIRED NO_DEFAULT_PATH) - add_library( ck_fused_attn STATIC IMPORTED ) - set_target_properties( ck_fused_attn PROPERTIES IMPORTED_LOCATION ${CK_FUSED_ATTN_LIB} ) - target_include_directories(ck_fused_attn INTERFACE ${CK_FUSED_ATTN_PATH}/include) - endif() + 
add_subdirectory(ck_fused_attn ${CMAKE_CURRENT_BINARY_DIR}/ck_fused_attn) endif() find_package(hip) diff --git a/transformer_engine/common/ck_fused_attn/CMakeLists.txt b/transformer_engine/common/ck_fused_attn/CMakeLists.txt index 2a2afa328..c44a930e6 100644 --- a/transformer_engine/common/ck_fused_attn/CMakeLists.txt +++ b/transformer_engine/common/ck_fused_attn/CMakeLists.txt @@ -1,20 +1,15 @@ # Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: MIT -#TODO: compile to a shared library -cmake_minimum_required(VERSION 3.28) -set(CMAKE_CXX_STANDARD 20) -#TODO: remove after figuring out how to install clang-scan-deps -set(CMAKE_CXX_SCAN_FOR_MODULES OFF) +cmake_minimum_required(VERSION 3.21) +set(CMAKE_CXX_STANDARD 17) project(ck_fused_attn LANGUAGES HIP CXX) -# remove files that should be regenerated -file(REMOVE_RECURSE ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp ${CMAKE_CURRENT_BINARY_DIR}/gen_src/blob_list.txt) -# create gen_src and gen_src/tmp directories if needed -file(MAKE_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp) +set(AITER_MHA_INSTALL_PREFIX "transformer_engine" CACHE STRING "aiter mha shared lib install prefix in TE") set(__AITER_SOURCE_DIR "${CMAKE_CURRENT_LIST_DIR}/../../../3rdparty/aiter") +set(__AITER_TEST_DIR "${__AITER_SOURCE_DIR}/op_tests/cpp/mha") set(__CK_SOURCE_DIR "${__AITER_SOURCE_DIR}/3rdparty/composable_kernel") # so far, there are only gfx942 and gfx950 v3 kernels @@ -37,82 +32,22 @@ message(STATUS "AITER V3_ASM_ARCHS: ${V3_ASM_ARCHS}") list(JOIN V3_ASM_ARCHS ";" V3_ASM_ARCHS_STR) set(ENV{GPU_ARCHS} "${V3_ASM_ARCHS_STR}") -# generate v2 (CK) kernels -# fwd kernels list -execute_process( - COMMAND python3 ${__CK_SOURCE_DIR}/example/ck_tile/01_fmha/generate.py - --api fwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/gen_src/fwd_blob_list.txt --receipt 600 -) -execute_process( - COMMAND python3 ${__CK_SOURCE_DIR}/example/ck_tile/01_fmha/generate.py - --api fwd_splitkv --list_blobs 
${CMAKE_CURRENT_BINARY_DIR}/gen_src/fwd_splitkv_blob_list.txt --receipt 600 -) -execute_process( - COMMAND python3 ${__CK_SOURCE_DIR}/example/ck_tile/01_fmha/generate.py - --api batch_prefill --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/gen_src/fwd_batch_prefill_blob_list.txt --receipt 600 -) - -# bwd kernels list -execute_process( - COMMAND python3 ${__CK_SOURCE_DIR}/example/ck_tile/01_fmha/generate.py - --api bwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/gen_src/bwd_blob_list.txt --receipt 600 -) - -file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/gen_src/fwd_blob_list.txt FMHA_FWD_GEN_BLOBS) -file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/gen_src/fwd_splitkv_blob_list.txt FMHA_FWD_SPLITKV_GEN_BLOBS) -file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/gen_src/fwd_batch_prefill_blob_list.txt FMHA_FWD_BATCH_PREFILL_GEN_BLOBS) -file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/gen_src/bwd_blob_list.txt FMHA_BWD_GEN_BLOBS) - -# generate the actual fwd kernel cpp files -execute_process( - COMMAND python3 ${__CK_SOURCE_DIR}/example/ck_tile/01_fmha/generate.py - --api fwd --output_dir ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp --receipt 600 -) - -execute_process( - COMMAND python3 ${__CK_SOURCE_DIR}/example/ck_tile/01_fmha/generate.py - --api fwd_splitkv --output_dir ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp --receipt 600 -) - -execute_process( - COMMAND python3 ${__CK_SOURCE_DIR}/example/ck_tile/01_fmha/generate.py - --api batch_prefill --output_dir ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp --receipt 600 -) - -# generate the aiter fwd interface cpp file -execute_process( - COMMAND python3 ${__AITER_SOURCE_DIR}/csrc/cpp_itfs/mha_fwd_generate.py - --output_dir ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp --receipt 5 -) - -# generate the actual bwd kernel cpp files -execute_process( - COMMAND python3 ${__CK_SOURCE_DIR}/example/ck_tile/01_fmha/generate.py - --api bwd --output_dir ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp --receipt 600 -) - -# generate the aiter bwd interface cpp file -execute_process( - COMMAND python3 
${__AITER_SOURCE_DIR}/csrc/py_itfs_cu/fmha_bwd_pre_post_kernel_generate.py - --filter *@*_ndeterministic@*_nbias*_dropout*_ndeterministic* --output_dir ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp -) - -execute_process( - COMMAND python3 ${__AITER_SOURCE_DIR}/csrc/cpp_itfs/mha_bwd_generate.py - --receipt 3 --output_dir ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp -) - -# generate fwd/bwd v3 kernels for each requested rocm arch -foreach(CK_TARGET_ARCH IN LISTS V3_ASM_ARCHS) - execute_process( - COMMAND python3 ${__AITER_SOURCE_DIR}/hsa/${CK_TARGET_ARCH}/fmha_v3_fwd/codegen.py - --output_dir ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp - ) +if(NOT DEFINED AITER_MHA_PATH) + # delete the existing aiter/jit/build dir for a clean build + file(REMOVE_RECURSE "${__AITER_SOURCE_DIR}/aiter/jit/build") + # compile the libmha_fwd.so and libmha_bwd.so + set(ENV{AITER_LOG_MORE} 1) + # fp32 to bf16 cvt env still required for MI300X + set(ENV{CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT} ${CK_FUSED_ATTN_FLOAT_TO_BFLOAT16_DEFAULT}) execute_process( - COMMAND python3 ${__AITER_SOURCE_DIR}/hsa/${CK_TARGET_ARCH}/fmha_v3_bwd/codegen.py - --output_dir ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp + COMMAND python3 ${__AITER_TEST_DIR}/compile.py COMMAND_ERROR_IS_FATAL ANY ) -endforeach() + # libmha_fwd.so and libmha_bwd.so will be under 3rdparty/aiter/op_tests/cpp/mha + set(__AITER_MHA_PATH ${__AITER_TEST_DIR}) +else() + # use pre-built libmha_fwd.so libmha_bwd.so + set(__AITER_MHA_PATH ${AITER_MHA_PATH}) +endif() set(ck_fused_attn_SOURCES) list(APPEND ck_fused_attn_SOURCES @@ -120,75 +55,18 @@ list(APPEND ck_fused_attn_SOURCES src/ck_fused_attn_bwd.cpp src/ck_fused_attn_utils.cpp) -foreach(blob ${FMHA_FWD_GEN_BLOBS}) - file(RELATIVE_PATH blob_path ${CMAKE_CURRENT_BINARY_DIR}/gen_src ${blob}) - file(COPY_FILE ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp/${blob_path} ${blob} ONLY_IF_DIFFERENT) -endforeach() -list(APPEND ck_fused_attn_SOURCES ${FMHA_FWD_GEN_BLOBS}) - -foreach(blob ${FMHA_FWD_SPLITKV_GEN_BLOBS}) - file(RELATIVE_PATH blob_path 
${CMAKE_CURRENT_BINARY_DIR}/gen_src ${blob}) - file(COPY_FILE ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp/${blob_path} ${blob} ONLY_IF_DIFFERENT) -endforeach() -list(APPEND ck_fused_attn_SOURCES ${FMHA_FWD_SPLITKV_GEN_BLOBS}) - -foreach(blob ${FMHA_FWD_BATCH_PREFILL_GEN_BLOBS}) - file(RELATIVE_PATH blob_path ${CMAKE_CURRENT_BINARY_DIR}/gen_src ${blob}) - file(COPY_FILE ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp/${blob_path} ${blob} ONLY_IF_DIFFERENT) -endforeach() -list(APPEND ck_fused_attn_SOURCES ${FMHA_FWD_BATCH_PREFILL_GEN_BLOBS}) - -foreach(blob ${FMHA_BWD_GEN_BLOBS}) - file(RELATIVE_PATH blob_path ${CMAKE_CURRENT_BINARY_DIR}/gen_src ${blob}) - file(COPY_FILE ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp/${blob_path} ${blob} ONLY_IF_DIFFERENT) -endforeach() -list(APPEND ck_fused_attn_SOURCES ${FMHA_BWD_GEN_BLOBS}) - -# add generated cpp files into ck_fused_attn_sources -set(MHA_BWD_SRC "${CMAKE_CURRENT_BINARY_DIR}/gen_src/mha_bwd.cpp") -set(MHA_FWD_SRC "${CMAKE_CURRENT_BINARY_DIR}/gen_src/mha_fwd.cpp") - -file(RELATIVE_PATH blob_path ${CMAKE_CURRENT_BINARY_DIR}/gen_src ${MHA_BWD_SRC}) -file(COPY_FILE ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp/${blob_path} ${MHA_BWD_SRC} ONLY_IF_DIFFERENT) - -file(RELATIVE_PATH blob_path ${CMAKE_CURRENT_BINARY_DIR}/gen_src ${MHA_FWD_SRC}) -file(COPY_FILE ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp/${blob_path} ${MHA_FWD_SRC} ONLY_IF_DIFFERENT) - -list(APPEND ck_fused_attn_SOURCES ${MHA_BWD_SRC} ${MHA_FWD_SRC}) - -foreach(CK_TARGET_ARCH IN LISTS V3_ASM_ARCHS) - set(ASM_MHA_FWD_SRC "${CMAKE_CURRENT_BINARY_DIR}/gen_src/asm_fmha_fwd_v3_${CK_TARGET_ARCH}.cpp") - set(ASM_MHA_BWD_SRC "${CMAKE_CURRENT_BINARY_DIR}/gen_src/asm_fmha_bwd_v3_${CK_TARGET_ARCH}.cpp") - - file(RELATIVE_PATH blob_path ${CMAKE_CURRENT_BINARY_DIR}/gen_src ${ASM_MHA_BWD_SRC}) - file(COPY_FILE ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp/${blob_path} ${ASM_MHA_BWD_SRC} ONLY_IF_DIFFERENT) - - file(RELATIVE_PATH blob_path ${CMAKE_CURRENT_BINARY_DIR}/gen_src ${ASM_MHA_FWD_SRC}) - 
file(COPY_FILE ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp/${blob_path} ${ASM_MHA_FWD_SRC} ONLY_IF_DIFFERENT) - list(APPEND ck_fused_attn_SOURCES ${ASM_MHA_BWD_SRC} ${ASM_MHA_FWD_SRC}) -endforeach() - -# remove all previously generated temporary files -file(REMOVE_RECURSE ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp) - message(STATUS "Found the following fused attention files:") foreach(file ${ck_fused_attn_SOURCES}) message(STATUS " ${file}") endforeach() -add_library(ck_fused_attn STATIC ${ck_fused_attn_SOURCES}) +add_library(ck_fused_attn SHARED ${ck_fused_attn_SOURCES}) set(CK_FUSED_ATTN_COMPILE_OPTIONS) list(APPEND CK_FUSED_ATTN_COMPILE_OPTIONS - -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -DCK_TILE_FMHA_FWD_SPLITKV_API=1-DCK_TILE_FMHA_FWD_APPENDKV_API=0 - -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=${CK_FUSED_ATTN_FLOAT_TO_BFLOAT16_DEFAULT} - -fgpu-flush-denormals-to-zero -ftemplate-backtrace-limit=0 -fPIC - -Wno-undefined-func-template -Wno-float-equal -Wno-gnu-line-marker -Wunused-variable -Wuninitialized - "SHELL:-mllvm -enable-post-misched=0" "SHELL:-mllvm -amdgpu-early-inline-all=true" - "SHELL:-mllvm -amdgpu-function-calls=false" "SHELL:-mllvm -amdgpu-coerce-illegal-types=1" - "SHELL:-mllvm --amdgpu-kernarg-preload-count=16") + -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=${CK_FUSED_ATTN_FLOAT_TO_BFLOAT16_DEFAULT}) -foreach(CK_TARGET_ARCH IN LISTS CMAKE_HIP_ARCHITECTURES) - list(APPEND CK_FUSED_ATTN_COMPILE_OPTIONS --offload-arch=${CK_TARGET_ARCH}) +foreach(ARCH IN LISTS V3_ASM_ARCHS) + list(APPEND CK_FUSED_ATTN_COMPILE_OPTIONS --offload-arch=${ARCH}) endforeach() set(CK_INCLUDE_DIR "${__CK_SOURCE_DIR}/include") @@ -216,18 +94,22 @@ target_include_directories(ck_fused_attn PRIVATE ${CK_INCLUDE_DIR} ${__CK_SOURCE target_include_directories(ck_fused_attn PRIVATE ${AITER_INCLUDE_DIR}) find_package(hip) -list(APPEND ck_fused_attn_LINKER_LIBS hip::host hip::device roctx64) +list(APPEND ck_fused_attn_LINKER_LIBS hip::host hip::device roctx64 ${__AITER_MHA_PATH}/libmha_fwd.so 
${__AITER_MHA_PATH}/libmha_bwd.so) target_link_libraries(ck_fused_attn PUBLIC ${ck_fused_attn_LINKER_LIBS}) target_compile_options(ck_fused_attn PRIVATE ${CK_FUSED_ATTN_COMPILE_OPTIONS}) +set_target_properties(ck_fused_attn PROPERTIES INSTALL_RPATH "$ORIGIN") +install(FILES ${__AITER_MHA_PATH}/libmha_fwd.so ${__AITER_MHA_PATH}/libmha_bwd.so DESTINATION ${CMAKE_INSTALL_PREFIX}/${AITER_MHA_INSTALL_PREFIX}/lib) +install(TARGETS ck_fused_attn DESTINATION ${CMAKE_INSTALL_PREFIX}/${AITER_MHA_INSTALL_PREFIX}/lib) # copy v3 kernels to destination foreach(ARCH IN LISTS V3_ASM_ARCHS) install(DIRECTORY ${__AITER_SOURCE_DIR}/hsa/${ARCH}/fmha_v3_fwd - DESTINATION ${CMAKE_INSTALL_PREFIX}/transformer_engine/aiter/${ARCH}/ + DESTINATION ${CMAKE_INSTALL_PREFIX}/${AITER_MHA_INSTALL_PREFIX}/lib/aiter/${ARCH}/ PATTERN "codegen.py" EXCLUDE) install(DIRECTORY ${__AITER_SOURCE_DIR}/hsa/${ARCH}/fmha_v3_bwd - DESTINATION ${CMAKE_INSTALL_PREFIX}/transformer_engine/aiter/${ARCH}/ + DESTINATION ${CMAKE_INSTALL_PREFIX}/${AITER_MHA_INSTALL_PREFIX}/lib/aiter/${ARCH}/ PATTERN "codegen.py" EXCLUDE) endforeach() + diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp index 840db7b86..2b717ace0 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp @@ -920,8 +920,8 @@ hipError_t ck_attn_varlen_bwd( cu_seqlen_q_ptr,//cu_seqlen_q cu_seqlen_kv_ptr,//cu_seqlen_kv nullptr, /* seqlen_k_ptr */ - 0, //seqlen_q, unused in group mode - 0, //seqlen_kv, unused in group mode + max_seqlen_q, //seqlen_q, unused in group mode + max_seqlen_k, //seqlen_kv, unused in group mode batch, max_seqlen_q, max_seqlen_k, diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp index 2829175ab..c87a3db6c 100644 --- 
a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp @@ -209,9 +209,13 @@ hipError_t ck_attn_fwd( nullptr,//rand_val_ptr lse_ptr, o_ptr, - nullptr,//cu_seqlen_q - nullptr,//cu_seqlen_kv - nullptr, /* seqlen_k_ptr */ + nullptr, //cu_seqlen_q + nullptr, //cu_seqlen_kv + nullptr, //seqstart_q_ptr + nullptr, //seqstart_k_ptr + nullptr, //seqlen_k_ptr + nullptr, //seqstart_padded_q_ptr + nullptr, //seqstart_padded_k_ptr max_seqlen_q, max_seqlen_k, batch, @@ -308,6 +312,7 @@ hipError_t ck_attn_varlen_fwd( ck_tile::index_t nhead_k = hg; ck_tile::index_t hdim_v = d_v; ck_tile::index_t max_seqlen_q = s_q; + ck_tile::index_t max_seqlen_kv = s_kv; float scale_s = scaling_factor; float scale_p = 1.f; @@ -379,11 +384,15 @@ hipError_t ck_attn_varlen_fwd( nullptr,//rand_val_ptr lse_thd_ptr, o_ptr, - cu_seqlen_q_ptr,//cu_seqlen_q - cu_seqlen_kv_ptr,//cu_seqlen_kv - nullptr, /* seqlen_k_ptr */ - 0, //seqlen_q, unused in group mode - 0, //seqlen_kv, unused in group mode + nullptr, //cu_seqlen_q + nullptr, //cu_seqlen_kv + cu_seqlen_q_ptr, //seqstart_q_ptr + cu_seqlen_kv_ptr, //seqstart_k_ptr + nullptr, //seqlen_k_ptr + nullptr, //seqstart_padded_q_ptr + nullptr, //seqstart_padded_k_ptr + max_seqlen_q, //seqlen_q, unused in group mode + max_seqlen_kv, //seqlen_kv, unused in group mode batch, max_seqlen_q, hdim_q, diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp index 72696fbd9..b38249f5b 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp @@ -557,6 +557,7 @@ void fused_attn_ck_fwd_impl( nvte_log_ck_config = true; } bool nvte_ck_uses_fwd_v3 = getenv("NVTE_CK_USES_FWD_V3", 0); + bool is_ragged = nvte_get_qkv_format(layout)==NVTE_QKV_Format::NVTE_THD; // extract the qkv and o storage bytes to allocate buffer for padding removing