Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,5 @@ compile_commands.json
**/profiler_outputs/
**/times.csv
tensor_dumps/
aiter/
transformer_engine/build_info.txt
transformer_engine/common/util/hip_nvml.*
transformer_engine/aiter/
2 changes: 1 addition & 1 deletion 3rdparty/aiter
Submodule aiter updated 273 files
30 changes: 4 additions & 26 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@


from setuptools.command.build_ext import build_ext as BuildExtension
from setuptools.command.develop import develop as _develop

os.environ["NVTE_PROJECT_BUILDING"] = "1"

Expand All @@ -48,26 +47,6 @@
if not rocm_build():
archs = cuda_archs()

# A custom develop command only used for ROCm builds
class develop(_develop):
def run(self):
super().run()
if (
int(os.getenv("NVTE_FUSED_ATTN_CK", "1")) and
int(os.getenv("NVTE_FUSED_ATTN", "1"))
):
# Ensure that the AITER ASM kernels are properly available at runtime
# by creating a symlink to them. This is only necessary for editable
# mode since our C++ code assumes the AITER ASM kernel paths relative
# to trasnformer_engine.so, which is different in editable installs.
project_dir = Path(__file__).parent
asm_src_dir = project_dir / 'transformer_engine' / 'aiter'
# Must be synced with
# TransformerEngine/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp
asm_target_dir = project_dir / 'aiter'
if asm_src_dir.is_dir() and not asm_target_dir.is_dir():
asm_target_dir.symlink_to(asm_src_dir)

class TimedBdist(bdist_wheel):
"""Helper class to measure build time"""

Expand All @@ -89,7 +68,7 @@ def setup_common_extension() -> CMakeExtension:
cmake_flags.append(f"-DCK_FUSED_ATTN_FLOAT_TO_BFLOAT16_DEFAULT={os.getenv('NVTE_CK_FUSED_ATTN_FLOAT_TO_BFLOAT16_DEFAULT', 3)}")
if os.getenv("NVTE_CK_FUSED_ATTN_PATH"):
ck_path = Path(os.getenv("NVTE_CK_FUSED_ATTN_PATH"))
cmake_flags.append(f"-DCK_FUSED_ATTN_PATH={ck_path}")
cmake_flags.append(f"-DAITER_MHA_PATH={ck_path}")
if int(os.getenv("NVTE_FUSED_ATTN_AOTRITON", "1"))==0 or int(os.getenv("NVTE_FUSED_ATTN", "1"))==0:
cmake_flags.append("-DUSE_FUSED_ATTN_AOTRITON=OFF")
if int(os.getenv("NVTE_FUSED_ATTN_CK", "1"))==0 or int(os.getenv("NVTE_FUSED_ATTN", "1"))==0:
Expand Down Expand Up @@ -192,14 +171,14 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
with open("README.rst", encoding="utf-8") as f:
long_description = f.read()

cmdclass = {"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist}
# Settings for building top level empty package for dependency management.
if bool(int(os.getenv("NVTE_BUILD_METAPACKAGE", "0"))):
assert bool(
int(os.getenv("NVTE_RELEASE_BUILD", "0"))
), "NVTE_RELEASE_BUILD env must be set for metapackage build."
te_cuda_vers = "rocm" if rocm_build() else "cu12"
ext_modules = []
cmdclass = {}
package_data = {}
include_package_data = False
setup_requires = []
Expand All @@ -211,8 +190,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
else:
setup_requires, install_requires, test_requires = setup_requirements()
ext_modules = [setup_common_extension()]
if rocm_build():
cmdclass["develop"] = develop
cmdclass = {"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist}
package_data = {"": ["VERSION.txt"]}
include_package_data = True
extras_require = {"test": test_requires}
Expand Down Expand Up @@ -255,7 +233,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
long_description=long_description,
long_description_content_type="text/x-rst",
ext_modules=ext_modules,
cmdclass=cmdclass,
cmdclass={"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist},
python_requires=">=3.8, <3.13",
classifiers=[
"Programming Language :: Python :: 3.8",
Expand Down
13 changes: 1 addition & 12 deletions transformer_engine/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -351,18 +351,7 @@ else()
endif()

if(USE_FUSED_ATTN_CK)
if(NOT DEFINED CK_FUSED_ATTN_PATH)
set(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT ${CK_FUSED_ATTN_FLOAT_TO_BFLOAT16_DEFAULT} CACHE STRING "ck float to bf16 conversion rounding")
add_subdirectory(ck_fused_attn ${CMAKE_CURRENT_BINARY_DIR}/ck_fused_attn)
else()
# Use CK built during initial TE building/installation
# When only need rebuild TE library itself
unset(CK_FUSED_ATTN_LIB CACHE)
find_library(CK_FUSED_ATTN_LIB NAMES ck_fused_attn PATHS ${CK_FUSED_ATTN_PATH}/lib REQUIRED NO_DEFAULT_PATH)
add_library( ck_fused_attn STATIC IMPORTED )
set_target_properties( ck_fused_attn PROPERTIES IMPORTED_LOCATION ${CK_FUSED_ATTN_LIB} )
target_include_directories(ck_fused_attn INTERFACE ${CK_FUSED_ATTN_PATH}/include)
endif()
add_subdirectory(ck_fused_attn ${CMAKE_CURRENT_BINARY_DIR}/ck_fused_attn)
endif()

find_package(hip)
Expand Down
176 changes: 29 additions & 147 deletions transformer_engine/common/ck_fused_attn/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,20 +1,15 @@
# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: MIT

#TODO: compile to a shared library
cmake_minimum_required(VERSION 3.28)
set(CMAKE_CXX_STANDARD 20)
#TODO: remove after figuring out how to install clang-scan-deps
set(CMAKE_CXX_SCAN_FOR_MODULES OFF)
cmake_minimum_required(VERSION 3.21)
set(CMAKE_CXX_STANDARD 17)
project(ck_fused_attn LANGUAGES HIP CXX)

# remove files that should be regenerated
file(REMOVE_RECURSE ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp ${CMAKE_CURRENT_BINARY_DIR}/gen_src/blob_list.txt)

# create gen_src and gen_src/tmp directories if needed
file(MAKE_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp)
set(AITER_MHA_INSTALL_PREFIX "transformer_engine" CACHE STRING "aiter mha shared lib install prefix in TE")

set(__AITER_SOURCE_DIR "${CMAKE_CURRENT_LIST_DIR}/../../../3rdparty/aiter")
set(__AITER_TEST_DIR "${__AITER_SOURCE_DIR}/op_tests/cpp/mha")
set(__CK_SOURCE_DIR "${__AITER_SOURCE_DIR}/3rdparty/composable_kernel")

# so far, there are only gfx942 and gfx950 v3 kernels
Expand All @@ -37,158 +32,41 @@ message(STATUS "AITER V3_ASM_ARCHS: ${V3_ASM_ARCHS}")
list(JOIN V3_ASM_ARCHS ";" V3_ASM_ARCHS_STR)
set(ENV{GPU_ARCHS} "${V3_ASM_ARCHS_STR}")

# generate v2 (CK) kernels
# fwd kernels list
execute_process(
COMMAND python3 ${__CK_SOURCE_DIR}/example/ck_tile/01_fmha/generate.py
--api fwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/gen_src/fwd_blob_list.txt --receipt 600
)
execute_process(
COMMAND python3 ${__CK_SOURCE_DIR}/example/ck_tile/01_fmha/generate.py
--api fwd_splitkv --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/gen_src/fwd_splitkv_blob_list.txt --receipt 600
)
execute_process(
COMMAND python3 ${__CK_SOURCE_DIR}/example/ck_tile/01_fmha/generate.py
--api batch_prefill --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/gen_src/fwd_batch_prefill_blob_list.txt --receipt 600
)

# bwd kernels list
execute_process(
COMMAND python3 ${__CK_SOURCE_DIR}/example/ck_tile/01_fmha/generate.py
--api bwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/gen_src/bwd_blob_list.txt --receipt 600
)

file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/gen_src/fwd_blob_list.txt FMHA_FWD_GEN_BLOBS)
file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/gen_src/fwd_splitkv_blob_list.txt FMHA_FWD_SPLITKV_GEN_BLOBS)
file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/gen_src/fwd_batch_prefill_blob_list.txt FMHA_FWD_BATCH_PREFILL_GEN_BLOBS)
file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/gen_src/bwd_blob_list.txt FMHA_BWD_GEN_BLOBS)

# generate the actual fwd kernel cpp files
execute_process(
COMMAND python3 ${__CK_SOURCE_DIR}/example/ck_tile/01_fmha/generate.py
--api fwd --output_dir ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp --receipt 600
)

execute_process(
COMMAND python3 ${__CK_SOURCE_DIR}/example/ck_tile/01_fmha/generate.py
--api fwd_splitkv --output_dir ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp --receipt 600
)

execute_process(
COMMAND python3 ${__CK_SOURCE_DIR}/example/ck_tile/01_fmha/generate.py
--api batch_prefill --output_dir ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp --receipt 600
)

# generate the aiter fwd interface cpp file
execute_process(
COMMAND python3 ${__AITER_SOURCE_DIR}/csrc/cpp_itfs/mha_fwd_generate.py
--output_dir ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp --receipt 5
)

# generate the actual bwd kernel cpp files
execute_process(
COMMAND python3 ${__CK_SOURCE_DIR}/example/ck_tile/01_fmha/generate.py
--api bwd --output_dir ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp --receipt 600
)

# generate the aiter bwd interface cpp file
execute_process(
COMMAND python3 ${__AITER_SOURCE_DIR}/csrc/py_itfs_cu/fmha_bwd_pre_post_kernel_generate.py
--filter *@*_ndeterministic@*_nbias*_dropout*_ndeterministic* --output_dir ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp
)

execute_process(
COMMAND python3 ${__AITER_SOURCE_DIR}/csrc/cpp_itfs/mha_bwd_generate.py
--receipt 3 --output_dir ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp
)

# generate fwd/bwd v3 kernels for each requested rocm arch
foreach(CK_TARGET_ARCH IN LISTS V3_ASM_ARCHS)
execute_process(
COMMAND python3 ${__AITER_SOURCE_DIR}/hsa/${CK_TARGET_ARCH}/fmha_v3_fwd/codegen.py
--output_dir ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp
)
if(NOT DEFINED AITER_MHA_PATH)
# delete the existing aiter/jit/build dir for a clean build
file(REMOVE_RECURSE "${__AITER_SOURCE_DIR}/aiter/jit/build")
# compile the libmha_fwd.so and libmha_bwd.so
set(ENV{AITER_LOG_MORE} 1)
# fp32 to bf16 cvt env still required for MI300X
set(ENV{CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT} ${CK_FUSED_ATTN_FLOAT_TO_BFLOAT16_DEFAULT})
execute_process(
COMMAND python3 ${__AITER_SOURCE_DIR}/hsa/${CK_TARGET_ARCH}/fmha_v3_bwd/codegen.py
--output_dir ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp
COMMAND python3 ${__AITER_TEST_DIR}/compile.py
)
endforeach()
# libmha_fwd.so and libmha_bwd.so will be under 3rdparty/aiter/op_tests/cpp/mha
set(__AITER_MHA_PATH ${__AITER_TEST_DIR})
else()
# use pre-built libmha_fwd.so libmha_bwd.so
set(__AITER_MHA_PATH ${AITER_MHA_PATH})
endif()

set(ck_fused_attn_SOURCES)
list(APPEND ck_fused_attn_SOURCES
src/ck_fused_attn_fwd.cpp
src/ck_fused_attn_bwd.cpp
src/ck_fused_attn_utils.cpp)

foreach(blob ${FMHA_FWD_GEN_BLOBS})
file(RELATIVE_PATH blob_path ${CMAKE_CURRENT_BINARY_DIR}/gen_src ${blob})
file(COPY_FILE ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp/${blob_path} ${blob} ONLY_IF_DIFFERENT)
endforeach()
list(APPEND ck_fused_attn_SOURCES ${FMHA_FWD_GEN_BLOBS})

foreach(blob ${FMHA_FWD_SPLITKV_GEN_BLOBS})
file(RELATIVE_PATH blob_path ${CMAKE_CURRENT_BINARY_DIR}/gen_src ${blob})
file(COPY_FILE ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp/${blob_path} ${blob} ONLY_IF_DIFFERENT)
endforeach()
list(APPEND ck_fused_attn_SOURCES ${FMHA_FWD_SPLITKV_GEN_BLOBS})

foreach(blob ${FMHA_FWD_BATCH_PREFILL_GEN_BLOBS})
file(RELATIVE_PATH blob_path ${CMAKE_CURRENT_BINARY_DIR}/gen_src ${blob})
file(COPY_FILE ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp/${blob_path} ${blob} ONLY_IF_DIFFERENT)
endforeach()
list(APPEND ck_fused_attn_SOURCES ${FMHA_FWD_BATCH_PREFILL_GEN_BLOBS})

foreach(blob ${FMHA_BWD_GEN_BLOBS})
file(RELATIVE_PATH blob_path ${CMAKE_CURRENT_BINARY_DIR}/gen_src ${blob})
file(COPY_FILE ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp/${blob_path} ${blob} ONLY_IF_DIFFERENT)
endforeach()
list(APPEND ck_fused_attn_SOURCES ${FMHA_BWD_GEN_BLOBS})

# add generated cpp files into ck_fused_attn_sources
set(MHA_BWD_SRC "${CMAKE_CURRENT_BINARY_DIR}/gen_src/mha_bwd.cpp")
set(MHA_FWD_SRC "${CMAKE_CURRENT_BINARY_DIR}/gen_src/mha_fwd.cpp")

file(RELATIVE_PATH blob_path ${CMAKE_CURRENT_BINARY_DIR}/gen_src ${MHA_BWD_SRC})
file(COPY_FILE ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp/${blob_path} ${MHA_BWD_SRC} ONLY_IF_DIFFERENT)

file(RELATIVE_PATH blob_path ${CMAKE_CURRENT_BINARY_DIR}/gen_src ${MHA_FWD_SRC})
file(COPY_FILE ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp/${blob_path} ${MHA_FWD_SRC} ONLY_IF_DIFFERENT)

list(APPEND ck_fused_attn_SOURCES ${MHA_BWD_SRC} ${MHA_FWD_SRC})

foreach(CK_TARGET_ARCH IN LISTS V3_ASM_ARCHS)
set(ASM_MHA_FWD_SRC "${CMAKE_CURRENT_BINARY_DIR}/gen_src/asm_fmha_fwd_v3_${CK_TARGET_ARCH}.cpp")
set(ASM_MHA_BWD_SRC "${CMAKE_CURRENT_BINARY_DIR}/gen_src/asm_fmha_bwd_v3_${CK_TARGET_ARCH}.cpp")

file(RELATIVE_PATH blob_path ${CMAKE_CURRENT_BINARY_DIR}/gen_src ${ASM_MHA_BWD_SRC})
file(COPY_FILE ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp/${blob_path} ${ASM_MHA_BWD_SRC} ONLY_IF_DIFFERENT)

file(RELATIVE_PATH blob_path ${CMAKE_CURRENT_BINARY_DIR}/gen_src ${ASM_MHA_FWD_SRC})
file(COPY_FILE ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp/${blob_path} ${ASM_MHA_FWD_SRC} ONLY_IF_DIFFERENT)
list(APPEND ck_fused_attn_SOURCES ${ASM_MHA_BWD_SRC} ${ASM_MHA_FWD_SRC})
endforeach()

# remove all previously generated temporary files
file(REMOVE_RECURSE ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp)

message(STATUS "Found the following fused attention files:")
foreach(file ${ck_fused_attn_SOURCES})
message(STATUS " ${file}")
endforeach()

add_library(ck_fused_attn STATIC ${ck_fused_attn_SOURCES})
add_library(ck_fused_attn SHARED ${ck_fused_attn_SOURCES})
set(CK_FUSED_ATTN_COMPILE_OPTIONS)
list(APPEND CK_FUSED_ATTN_COMPILE_OPTIONS
-DCK_TILE_FMHA_FWD_FAST_EXP2=1 -DCK_TILE_FMHA_FWD_SPLITKV_API=1-DCK_TILE_FMHA_FWD_APPENDKV_API=0
-DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=${CK_FUSED_ATTN_FLOAT_TO_BFLOAT16_DEFAULT}
-fgpu-flush-denormals-to-zero -ftemplate-backtrace-limit=0 -fPIC
-Wno-undefined-func-template -Wno-float-equal -Wno-gnu-line-marker -Wunused-variable -Wuninitialized
"SHELL:-mllvm -enable-post-misched=0" "SHELL:-mllvm -amdgpu-early-inline-all=true"
"SHELL:-mllvm -amdgpu-function-calls=false" "SHELL:-mllvm -amdgpu-coerce-illegal-types=1"
"SHELL:-mllvm --amdgpu-kernarg-preload-count=16")
-DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=${CK_FUSED_ATTN_FLOAT_TO_BFLOAT16_DEFAULT})

foreach(CK_TARGET_ARCH IN LISTS CMAKE_HIP_ARCHITECTURES)
list(APPEND CK_FUSED_ATTN_COMPILE_OPTIONS --offload-arch=${CK_TARGET_ARCH})
foreach(ARCH IN LISTS V3_ASM_ARCHS)
list(APPEND CK_FUSED_ATTN_COMPILE_OPTIONS --offload-arch=${ARCH})
endforeach()

set(CK_INCLUDE_DIR "${__CK_SOURCE_DIR}/include")
Expand Down Expand Up @@ -216,18 +94,22 @@ target_include_directories(ck_fused_attn PRIVATE ${CK_INCLUDE_DIR} ${__CK_SOURCE
target_include_directories(ck_fused_attn PRIVATE ${AITER_INCLUDE_DIR})

find_package(hip)
list(APPEND ck_fused_attn_LINKER_LIBS hip::host hip::device roctx64)
list(APPEND ck_fused_attn_LINKER_LIBS hip::host hip::device roctx64 ${__AITER_MHA_PATH}/libmha_fwd.so ${__AITER_MHA_PATH}/libmha_bwd.so)
target_link_libraries(ck_fused_attn PUBLIC ${ck_fused_attn_LINKER_LIBS})
target_compile_options(ck_fused_attn PRIVATE ${CK_FUSED_ATTN_COMPILE_OPTIONS})
set_target_properties(ck_fused_attn PROPERTIES INSTALL_RPATH "$ORIGIN")

install(FILES ${__AITER_MHA_PATH}/libmha_fwd.so ${__AITER_MHA_PATH}/libmha_bwd.so DESTINATION ${CMAKE_INSTALL_PREFIX}/${AITER_MHA_INSTALL_PREFIX}/lib)
install(TARGETS ck_fused_attn DESTINATION ${CMAKE_INSTALL_PREFIX}/${AITER_MHA_INSTALL_PREFIX}/lib)
# copy v3 kernels to destination
foreach(ARCH IN LISTS V3_ASM_ARCHS)
install(DIRECTORY
${__AITER_SOURCE_DIR}/hsa/${ARCH}/fmha_v3_fwd
DESTINATION ${CMAKE_INSTALL_PREFIX}/transformer_engine/aiter/${ARCH}/
DESTINATION ${CMAKE_INSTALL_PREFIX}/${AITER_MHA_INSTALL_PREFIX}/lib/aiter/${ARCH}/
PATTERN "codegen.py" EXCLUDE)
install(DIRECTORY
${__AITER_SOURCE_DIR}/hsa/${ARCH}/fmha_v3_bwd
DESTINATION ${CMAKE_INSTALL_PREFIX}/transformer_engine/aiter/${ARCH}/
DESTINATION ${CMAKE_INSTALL_PREFIX}/${AITER_MHA_INSTALL_PREFIX}/lib/aiter/${ARCH}/
PATTERN "codegen.py" EXCLUDE)
endforeach()

Original file line number Diff line number Diff line change
Expand Up @@ -920,8 +920,8 @@ hipError_t ck_attn_varlen_bwd(
cu_seqlen_q_ptr,//cu_seqlen_q
cu_seqlen_kv_ptr,//cu_seqlen_kv
nullptr, /* seqlen_k_ptr */
0, //seqlen_q, unused in group mode
0, //seqlen_kv, unused in group mode
max_seqlen_q, //seqlen_q, unused in group mode
max_seqlen_k, //seqlen_kv, unused in group mode
batch,
max_seqlen_q,
max_seqlen_k,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -209,9 +209,13 @@ hipError_t ck_attn_fwd(
nullptr,//rand_val_ptr
lse_ptr,
o_ptr,
nullptr,//cu_seqlen_q
nullptr,//cu_seqlen_kv
nullptr, /* seqlen_k_ptr */
nullptr, //cu_seqlen_q
nullptr, //cu_seqlen_kv
nullptr, //seqstart_q_ptr
nullptr, //seqstart_k_ptr
nullptr, //seqlen_k_ptr
nullptr, //seqstart_padded_q_ptr
nullptr, //seqstart_padded_k_ptr
max_seqlen_q,
max_seqlen_k,
batch,
Expand Down Expand Up @@ -308,6 +312,7 @@ hipError_t ck_attn_varlen_fwd(
ck_tile::index_t nhead_k = hg;
ck_tile::index_t hdim_v = d_v;
ck_tile::index_t max_seqlen_q = s_q;
ck_tile::index_t max_seqlen_kv = s_kv;

float scale_s = scaling_factor;
float scale_p = 1.f;
Expand Down Expand Up @@ -379,11 +384,15 @@ hipError_t ck_attn_varlen_fwd(
nullptr,//rand_val_ptr
lse_thd_ptr,
o_ptr,
cu_seqlen_q_ptr,//cu_seqlen_q
cu_seqlen_kv_ptr,//cu_seqlen_kv
nullptr, /* seqlen_k_ptr */
0, //seqlen_q, unused in group mode
0, //seqlen_kv, unused in group mode
nullptr, //cu_seqlen_q
nullptr, //cu_seqlen_kv
cu_seqlen_q_ptr, //seqstart_q_ptr
cu_seqlen_kv_ptr, //seqstart_k_ptr
nullptr, //seqlen_k_ptr
nullptr, //seqstart_padded_q_ptr
nullptr, //seqstart_padded_k_ptr
max_seqlen_q, //seqlen_q, unused in group mode
max_seqlen_kv, //seqlen_kv, unused in group mode
batch,
max_seqlen_q,
hdim_q,
Expand Down
Loading