18 changes: 12 additions & 6 deletions ci/core.sh
@@ -24,19 +24,25 @@ if [ $rc -ne 0 ]; then
exit $rc
fi

echo ===== Run non GEMM tests =====
ctest --test-dir build -j4 --output-on-failure -E "OperatorTest/GEMMTestSuite"
test $? -eq 0 || test_run_error "non-GEMM"
check_test_filter "nongemm"
if [ $? -eq 0 ]; then
echo ===== Run non GEMM tests =====
ctest --test-dir build -j4 --output-on-failure -E "OperatorTest/GEMMTestSuite"
test $? -eq 0 || test_run_error "non-GEMM"
fi

for _gemm in hipblaslt rocblas; do
configure_gemm_env $_gemm || continue
_exclude=""
if [ $_gemm = "hipblaslt" ]; then
_exclude="-E Test(.*bf16/.*X.X1|.*fp8.*fp16/.*X1X0|.*fp8.*X.X1|.*fp8/|.*bf8/)"
fi
echo ===== Run GEMM $_gemm tests =====
ctest --test-dir build -j4 --output-on-failure -R "OperatorTest/GEMMTestSuite" $_exclude
test $? -eq 0 || test_run_error "GEMM $_gemm"
check_test_filter $_gemm
if [ $? -eq 0 ]; then
echo ===== Run GEMM $_gemm tests =====
ctest --test-dir build -j4 --output-on-failure -R "OperatorTest/GEMMTestSuite" $_exclude
test $? -eq 0 || test_run_error "GEMM $_gemm"
fi
done

return_run_results
20 changes: 19 additions & 1 deletion tests/cpp/operator/test_cublaslt_gemm.cu
@@ -155,7 +155,13 @@ void performTest(bool use_bias, bool use_gelu, const size_t m, const size_t k, c
}
#endif

Tensor Workspace({ 33554432 }, DType::kByte);
size_t workspace_size = 33554432;
#ifdef __HIP_PLATFORM_AMD__
if (prop.major == 9 && prop.minor == 5) {
workspace_size = 67108864;
}
#endif
Tensor Workspace({ workspace_size }, DType::kByte);

//perform the gemm in GPU
nvte_cublas_gemm(A.data(),
@@ -212,6 +218,18 @@ void performTest(bool use_bias, bool use_gelu, const size_t m, const size_t k, c
if (dtype == DType::kFloat32) {
atol = 1e-5;
}
#ifdef __HIP_PLATFORM_AMD__
if (prop.major == 9 && prop.minor == 5)
{
// relax tolerances for certain GEMM cases with hipBLASLt
if (!isFp8Type(dtype) && (isFp8Type(atype) || isFp8Type(btype))) {
atol = 5e-4;
rtol = 5e-3;
} else if (dtype == DType::kFloat32) {
rtol = 1e-5;
}
}
#endif
compareResults("D", D, ref_D.get(), atol, rtol);

if(use_gelu){
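Taken together, the two hunks above make the GEMM test architecture-aware: devices reporting major.minor 9.5 (gfx950-class hardware) get a 64 MiB hipBLASLt workspace instead of the default 32 MiB, and looser tolerances when FP8 inputs accumulate into a non-FP8 output. Below is a minimal C++ sketch of the same selection logic factored into standalone helpers; the names select_workspace_bytes and select_tolerances are illustrative and not part of the test.

#include <cstddef>

// Illustrative helpers that mirror the per-architecture choices made in the test above.
struct GemmTolerances {
  double atol;
  double rtol;
};

// 32 MiB by default; 64 MiB on devices reporting 9.5.
inline size_t select_workspace_bytes(int major, int minor) {
  size_t bytes = 33554432;   // 32 MiB
  if (major == 9 && minor == 5) {
    bytes = 67108864;        // 64 MiB
  }
  return bytes;
}

// Start from the test's base tolerances, then apply the FP32 adjustment and the
// hipBLASLt-specific relaxations in the same order as the hunk above.
inline GemmTolerances select_tolerances(GemmTolerances base, bool out_is_fp32,
                                        bool out_is_fp8, bool any_input_fp8,
                                        int major, int minor) {
  GemmTolerances t = base;
  if (out_is_fp32) {
    t.atol = 1e-5;
  }
  if (major == 9 && minor == 5) {
    if (!out_is_fp8 && any_input_fp8) {
      t.atol = 5e-4;   // FP8 inputs with a wider output type
      t.rtol = 5e-3;
    } else if (out_is_fp32) {
      t.rtol = 1e-5;
    }
  }
  return t;
}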
5 changes: 3 additions & 2 deletions tests/pytorch/test_numerics.py
@@ -1972,8 +1972,9 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
if IS_HIP_EXTENSION:
if use_hipblaslt():
tols = dtype_tols(dtype)
if dtype in (torch.float16, torch.bfloat16) and is_mi308():
# mi308 hipblaslt precision issue
if dtype in (torch.float16, torch.bfloat16):
# On some GPUs, hipBLASLt results for the SBHD and BSHD formats differ,
# which lowers the precision of the final result
tols["atol"] = 2e-3
_, use_aotriton, use_ck = rocm_attn_backend()
if use_aotriton and not use_ck:
93 changes: 51 additions & 42 deletions transformer_engine/common/CMakeLists.txt
@@ -255,7 +255,7 @@ else()
set(__AOTRITON_INSTALL_DIR "${CMAKE_CURRENT_BINARY_DIR}/aotriton")
set(__AOTRITON_SUFFIX "_TEprivate")
if(NOT DEFINED AOTRITON_PATH)
# # Install aotriton fused attn
# Install aotriton fused attn
if(USE_FUSED_ATTN_AOTRITON_BUILD_GPU_KERNELS)
set(AOTRITON_NOIMAGE_MODE OFF)
else()
@@ -271,51 +271,60 @@
foreach(X IN LISTS CMAKE_HIP_ARCHITECTURES)
set(key ${X})
string(APPEND key "_key")
string(APPEND aotriton_target_gpus ${${key}})
string(APPEND aotriton_target_gpus "|")
set(gpu ${${key}})
if (gpu)
string(APPEND aotriton_target_gpus "${gpu}|")
else()
message(WARNING "AOTriton building is not supported for ${X}")
endif()
endforeach()
endmacro()
translate_arch_to_gpu_names(aotriton_target_gpus)
include(ExternalProject)
ExternalProject_Add(aotriton_external
SOURCE_DIR ../../3rdparty/aotriton
LIST_SEPARATOR |
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${__AOTRITON_INSTALL_DIR}
-DTARGET_GPUS=${aotriton_target_gpus}
-DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}
-DAOTRITON_NO_PYTHON=ON
-DAOTRITON_NAME_SUFFIX=${__AOTRITON_SUFFIX}
-DAOTRITON_NOIMAGE_MODE=${AOTRITON_NOIMAGE_MODE}
BUILD_BYPRODUCTS "${__AOTRITON_INSTALL_DIR}/lib/libaotriton${__AOTRITON_SUFFIX}_v2.so"
)
add_library(aotriton INTERFACE)
add_dependencies(aotriton aotriton_external)
target_link_libraries(aotriton INTERFACE ${__AOTRITON_INSTALL_DIR}/lib/libaotriton${__AOTRITON_SUFFIX}_v2.so)
target_include_directories(aotriton INTERFACE ${__AOTRITON_INSTALL_DIR}/include)
if(NOT USE_FUSED_ATTN_AOTRITON_BUILD_GPU_KERNELS)
set(__AOTRITON_VER "0.8.2b")
set(__AOTRITON_SHA256 "66445e6b0209b9f4080743b839cc9d424054dc5c8d07363f9f27f109231c324a")
string(CONCAT __AOTRITON_URL "https://github.com/ROCm/aotriton/releases/download/"
"${__AOTRITON_VER}/aotriton-"
"${__AOTRITON_VER}-manylinux_2_28"
"_x86_64-rocm6.2"
"-shared.tar.gz")
ExternalProject_Add(aotriton_images
URL "${__AOTRITON_URL}"
URL_HASH SHA256=${__AOTRITON_SHA256}
SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/aotriton_tarball
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
INSTALL_COMMAND ${CMAKE_COMMAND} -E copy_directory
"${CMAKE_CURRENT_BINARY_DIR}/aotriton_tarball/lib/aotriton.images"
"${__AOTRITON_INSTALL_DIR}/lib/aotriton.images")
add_dependencies(aotriton aotriton_images)
if (NOT aotriton_target_gpus)
set(USE_FUSED_ATTN_AOTRITON FALSE)
message(WARNING "Disable AOTriton building because no supported GPU targets found")
else()
include(ExternalProject)
ExternalProject_Add(aotriton_external
SOURCE_DIR ../../3rdparty/aotriton
LIST_SEPARATOR |
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${__AOTRITON_INSTALL_DIR}
-DTARGET_GPUS=${aotriton_target_gpus}
-DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}
-DAOTRITON_NO_PYTHON=ON
-DAOTRITON_NAME_SUFFIX=${__AOTRITON_SUFFIX}
-DAOTRITON_NOIMAGE_MODE=${AOTRITON_NOIMAGE_MODE}
BUILD_BYPRODUCTS "${__AOTRITON_INSTALL_DIR}/lib/libaotriton${__AOTRITON_SUFFIX}_v2.so"
)
add_library(aotriton INTERFACE)
add_dependencies(aotriton aotriton_external)
target_link_libraries(aotriton INTERFACE ${__AOTRITON_INSTALL_DIR}/lib/libaotriton${__AOTRITON_SUFFIX}_v2.so)
target_include_directories(aotriton INTERFACE ${__AOTRITON_INSTALL_DIR}/include)
if(NOT USE_FUSED_ATTN_AOTRITON_BUILD_GPU_KERNELS)
set(__AOTRITON_VER "0.8.2b")
set(__AOTRITON_SHA256 "66445e6b0209b9f4080743b839cc9d424054dc5c8d07363f9f27f109231c324a")
string(CONCAT __AOTRITON_URL "https://github.com/ROCm/aotriton/releases/download/"
"${__AOTRITON_VER}/aotriton-"
"${__AOTRITON_VER}-manylinux_2_28"
"_x86_64-rocm6.2"
"-shared.tar.gz")
ExternalProject_Add(aotriton_images
URL "${__AOTRITON_URL}"
URL_HASH SHA256=${__AOTRITON_SHA256}
SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/aotriton_tarball
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
INSTALL_COMMAND ${CMAKE_COMMAND} -E copy_directory
"${CMAKE_CURRENT_BINARY_DIR}/aotriton_tarball/lib/aotriton.images"
"${__AOTRITON_INSTALL_DIR}/lib/aotriton.images")
add_dependencies(aotriton aotriton_images)
endif()
install(DIRECTORY
${__AOTRITON_INSTALL_DIR}/lib
DESTINATION ${CMAKE_INSTALL_PREFIX}/transformer_engine
PATTERN "cmake" EXCLUDE
PATTERN "libaotriton${__AOTRITON_SUFFIX}_v2.so" EXCLUDE)
endif()
install(DIRECTORY
${__AOTRITON_INSTALL_DIR}/lib
DESTINATION ${CMAKE_INSTALL_PREFIX}/transformer_engine
PATTERN "cmake" EXCLUDE
PATTERN "libaotriton${__AOTRITON_SUFFIX}_v2.so" EXCLUDE)
else()
# Use aotriton built during initial TE building/installation
# When only need rebuild TE library itself
2 changes: 1 addition & 1 deletion transformer_engine/common/amd_detail/hip_float8.h
@@ -88,7 +88,7 @@ struct te_hip_fp8_e4m3 {

__host__ __device__ operator float() const { return data.operator float(); }

__host__ __device__ te_hip_fp8_e4m3(const float& v) { data = v;}
__host__ __device__ te_hip_fp8_e4m3(const float& v): data(v) {}
};
static_assert(sizeof(te_hip_fp8_e4m3) == 1, "Size mismatch");

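The one-line change above switches the float constructor of te_hip_fp8_e4m3 from assignment in the constructor body to a member initializer list, so data is constructed directly from v instead of being default-constructed and then assigned. A minimal standalone illustration of the pattern, using an invented wrapper type rather than the real FP8 storage type:

struct Payload {
  float value;
  Payload(float v) : value(v) {}   // built directly from the source value
};

struct Wrapper {
  Payload data;
  // Member initializer list: data is constructed from v in a single step,
  // mirroring the te_hip_fp8_e4m3(const float&) constructor above.
  Wrapper(const float& v) : data(v) {}
};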
4 changes: 3 additions & 1 deletion transformer_engine/common/recipe/delayed_scaling.cu
@@ -36,8 +36,10 @@ inline float fp8_dtype_max(DType dtype) {
case DType::kFloat8E4M3:
#ifndef __HIP_PLATFORM_AMD__
return 448;
#else
#elif HIP_VERSION >= 60300000
return te_fp8_fnuz() ? 240 : 448;
#else
return 240; // assume the FNUZ format on older HIP versions
#endif
case DType::kFloat8E5M2:
return 57344;
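The hunk above consults the runtime FNUZ check only when the HIP version is at least 6.3; older HIP builds fall back to the FNUZ limit unconditionally. A hedged sketch of the same E4M3 dispatch written as a plain helper (the function name e4m3_max is illustrative):

// OCP FP8 E4M3 saturates at 448; the FNUZ variant used by older ROCm stacks saturates at 240.
constexpr float kE4M3MaxOcp = 448.f;
constexpr float kE4M3MaxFnuz = 240.f;

inline float e4m3_max(bool runtime_check_available, bool is_fnuz) {
  // With no runtime check (HIP < 6.3), assume the FNUZ encoding and return
  // the conservative 240 limit, matching the #else branch above.
  if (!runtime_check_available) {
    return kE4M3MaxFnuz;
  }
  return is_fnuz ? kE4M3MaxFnuz : kE4M3MaxOcp;
}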
2 changes: 1 addition & 1 deletion transformer_engine/common/util/system.cpp
@@ -91,6 +91,6 @@ extern "C" bool nvte_uses_fp8_fnuz()
#if HIP_VERSION >= 60300000
return te_fp8_fnuz();
#endif
return true; // default to true for older versions that only support
return true; // default to true for compatibility with older HIP versions
}
#endif
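Since nvte_uses_fp8_fnuz() is exported with C linkage, code outside the library can branch on it without knowing the HIP version it was built against. A hedged sketch of such a caller; the consumer function max_e4m3_for_runtime is invented for the example:

extern "C" bool nvte_uses_fp8_fnuz();  // declaration of the exported query above

// Hypothetical consumer: pick the largest representable E4M3 value
// depending on which FP8 flavour the installed stack reports.
float max_e4m3_for_runtime() {
  return nvte_uses_fp8_fnuz() ? 240.f : 448.f;
}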
2 changes: 1 addition & 1 deletion transformer_engine/pytorch/attention.py
@@ -120,7 +120,7 @@ def _get_supported_versions(version_min, version_max):
_flash_attn_is_installed = False
_flash_attn_version = PkgVersion("0")
_flash_attn_version_required = PkgVersion("2.1.1")
_flash_attn_max_version = PkgVersion("2.7.3")
_flash_attn_max_version = PkgVersion("2.7.4.post1")
_flash_attn_2_plus = False
_flash_attn_2_1_plus = False
_flash_attn_2_3_plus = False