diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp
index a62b028fd4ff9..fed6031092cff 100644
--- a/aten/src/ATen/cuda/CUDABlas.cpp
+++ b/aten/src/ATen/cuda/CUDABlas.cpp
@@ -1524,15 +1524,19 @@ void scaled_gemm(
   computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, _cublasOpFromChar(transb));
   cublasLtMatmulDescAttributes_t matmulDescA = CUBLASLT_MATMUL_DESC_A_SCALE_POINTER;
   cublasLtMatmulDescAttributes_t matmulDescB = CUBLASLT_MATMUL_DESC_B_SCALE_POINTER;
-#if defined(USE_ROCM) && defined(HIPBLASLT_VEC_EXT)
+#if defined(USE_ROCM)
+#if defined(HIPBLASLT_OUTER_VEC)
+  // this case is handled later as hipified CUBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F
+#elif defined(HIPBLASLT_VEC_EXT)
   if (use_rowwise) {
     matmulDescA = HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT;
     matmulDescB = HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER_VEC_EXT;
   }
 #else
-  // rowwise isn't supported using cublaslt or older hipblaslt
-  TORCH_INTERNAL_ASSERT(use_rowwise == false, "rowwise scaled_gemm not supported with blaslt");
+  // rowwise isn't supported using older hipblaslt
+  TORCH_INTERNAL_ASSERT(use_rowwise == false, "rowwise scaled_gemm not supported with older hipblaslt");
 #endif
+#endif // defined(USE_ROCM)
   computeDesc.setAttribute(matmulDescA, mat1_scale_ptr);
   computeDesc.setAttribute(matmulDescB, mat2_scale_ptr);
   if (result_scale_ptr != nullptr) {
@@ -1572,7 +1576,16 @@
 #else
     TORCH_CHECK(false, "scaled_gemm with `torch.float8_e8m0fnu` scales is only supported for CUDA 12.8 and above");
 #endif // CUDA_VERSION >= 12080
-  }
+  } else if (mat1_scale_dtype == kFloat && mat2_scale_dtype == kFloat && use_rowwise) {
+#if CUDA_VERSION >= 12090 || (defined(USE_ROCM) && defined(HIPBLASLT_OUTER_VEC))
+    computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_MODE, CUBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F);
+    computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_MODE, CUBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F);
+#elif defined(USE_ROCM) && defined(HIPBLASLT_VEC_EXT)
+    // no-op here for older hipblaslt ext enums, to avoid TORCH_CHECK below
+#else
+    TORCH_CHECK(false, "scaled_gemm with `torch.float` outer vector scaling is only supported for CUDA 12.9 and above");
+#endif // if CUDA_VERSION >= 12090
+  }
   size_t workspaceSize = _getWorkspaceSize();
   auto workspace = at::empty(static_cast<int64_t>(workspaceSize), at::TensorOptions().dtype(at::kByte).device(at::kCUDA));
 
diff --git a/aten/src/ATen/cuda/tunable/GemmHipblaslt.h b/aten/src/ATen/cuda/tunable/GemmHipblaslt.h
index 640ff7331c61c..701b8e11aed1d 100644
--- a/aten/src/ATen/cuda/tunable/GemmHipblaslt.h
+++ b/aten/src/ATen/cuda/tunable/GemmHipblaslt.h
@@ -522,6 +522,12 @@ class HipblasltGemmOp : public Callable<ParamsT> {
         matmul.setAttribute(HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER, mat1_scale_ptr);
         matmul.setAttribute(HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER, mat2_scale_ptr);
       }
+#ifdef HIPBLASLT_OUTER_VEC
+      if (GetUseRowwiseFromParams<CT>(params)) {
+        matmul.setAttribute(HIPBLASLT_MATMUL_DESC_A_SCALE_MODE, HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F);
+        matmul.setAttribute(HIPBLASLT_MATMUL_DESC_B_SCALE_MODE, HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F);
+      }
+#endif
     }
     if (result_scale_ptr) {
       matmul.setAttribute(HIPBLASLT_MATMUL_DESC_D_SCALE_POINTER, result_scale_ptr);
diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp
index 66bb1c5ec285c..2c840997d5fc4 100644
--- a/aten/src/ATen/native/cuda/Blas.cpp
+++ b/aten/src/ATen/native/cuda/Blas.cpp
@@ -1056,7 +1056,7 @@ ScalingType get_scaling_type(
   if (scale_a.size(0) == dim_m && scale_a.size(1) == 1
       && scale_b.size(0) == 1 && scale_b.size(1) == dim_n) {
 #if (!defined(USE_ROCM) && !defined(_MSC_VER)) || \
-    (defined(USE_ROCM) && defined(HIPBLASLT_VEC_EXT))
+    (defined(USE_ROCM) && (defined(HIPBLASLT_VEC_EXT) || defined(HIPBLASLT_OUTER_VEC)))
     TORCH_CHECK(
         scale_a.is_contiguous() && scale_b.is_contiguous(),
         "Both scale_a and scale_b must be contiguous for RowWise scaling.");
diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake
index ba10d685a9f65..4abefaac8df76 100644
--- a/cmake/Dependencies.cmake
+++ b/cmake/Dependencies.cmake
@@ -1059,6 +1059,9 @@ if(USE_ROCM)
   list(APPEND HIP_CXX_FLAGS -DTHRUST_DEVICE_SYSTEM=THRUST_DEVICE_SYSTEM_HIP)
   list(APPEND HIP_CXX_FLAGS -std=c++17)
   list(APPEND HIP_CXX_FLAGS -DHIPBLAS_V2)
+  if(HIPBLASLT_OUTER_VEC)
+    list(APPEND HIP_CXX_FLAGS -DHIPBLASLT_OUTER_VEC)
+  endif()
   if(HIPBLASLT_VEC_EXT)
     list(APPEND HIP_CXX_FLAGS -DHIPBLASLT_VEC_EXT)
   endif()
diff --git a/cmake/public/LoadHIP.cmake b/cmake/public/LoadHIP.cmake
index 58c74ddda3505..14308ff34c1f3 100644
--- a/cmake/public/LoadHIP.cmake
+++ b/cmake/public/LoadHIP.cmake
@@ -178,6 +178,21 @@ if(HIP_FOUND)
   set(PROJECT_RANDOM_BINARY_DIR "${PROJECT_BINARY_DIR}")
 
   if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "5.7.0")
+    # check whether hipblaslt provides HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F
+    set(file "${PROJECT_BINARY_DIR}/hipblaslt_test_outer_vec.cc")
+    file(WRITE ${file} ""
+      "#define LEGACY_HIPBLAS_DIRECT\n"
+      "#include <hipblaslt/hipblaslt.h>\n"
+      "int main() {\n"
+      "    hipblasLtMatmulMatrixScale_t attr = HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F;\n"
+      "    return 0;\n"
+      "}\n"
+      )
+    try_compile(hipblaslt_compile_result_outer_vec ${PROJECT_RANDOM_BINARY_DIR} ${file}
+      CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${ROCM_INCLUDE_DIRS}"
+      COMPILE_DEFINITIONS -D__HIP_PLATFORM_AMD__ -D__HIP_PLATFORM_HCC__
+      OUTPUT_VARIABLE hipblaslt_compile_output_outer_vec)
+
     # check whether hipblaslt provides HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT
     set(file "${PROJECT_BINARY_DIR}/hipblaslt_test_vec_ext.cc")
     file(WRITE ${file} ""
@@ -191,15 +206,21 @@ if(HIP_FOUND)
     try_compile(hipblaslt_compile_result_vec_ext ${PROJECT_RANDOM_BINARY_DIR} ${file}
       CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${ROCM_INCLUDE_DIRS}"
      COMPILE_DEFINITIONS -D__HIP_PLATFORM_AMD__ -D__HIP_PLATFORM_HCC__
-      OUTPUT_VARIABLE hipblaslt_compile_output)
-    if(hipblaslt_compile_result_vec_ext)
+      OUTPUT_VARIABLE hipblaslt_compile_output_vec_ext)
+
+    if(hipblaslt_compile_result_outer_vec)
+      set(HIPBLASLT_OUTER_VEC ON)
+      set(HIPBLASLT_VEC_EXT OFF)
+      message("hipblaslt is using scale pointer outer vec")
+    elseif(hipblaslt_compile_result_vec_ext)
+      set(HIPBLASLT_OUTER_VEC OFF)
       set(HIPBLASLT_VEC_EXT ON)
-      #message("hipblaslt is using scale pointer vec ext: ${hipblaslt_compile_output}")
       message("hipblaslt is using scale pointer vec ext")
     else()
+      set(HIPBLASLT_OUTER_VEC OFF)
       set(HIPBLASLT_VEC_EXT OFF)
-      message("hipblaslt is NOT using scale pointer vec ext: ${hipblaslt_compile_output}")
-      #message("hipblaslt is NOT using scale pointer vec ext")
+      message("hipblaslt is NOT using scale pointer outer vec: ${hipblaslt_compile_output_outer_vec}")
+      message("hipblaslt is NOT using scale pointer vec ext: ${hipblaslt_compile_output_vec_ext}")
     endif()
   endif()
 endif()
diff --git a/torch/utils/hipify/cuda_to_hip_mappings.py b/torch/utils/hipify/cuda_to_hip_mappings.py
index 1c9f90346331c..17ccc59ff3acf 100644
--- a/torch/utils/hipify/cuda_to_hip_mappings.py
+++ b/torch/utils/hipify/cuda_to_hip_mappings.py
@@ -7320,6 +7320,9 @@
         ("CUBLASLT_MATMUL_DESC_A_SCALE_POINTER", ("HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER", CONV_MATH_FUNC, API_BLAS)),
         ("CUBLASLT_MATMUL_DESC_B_SCALE_POINTER", ("HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER", CONV_MATH_FUNC, API_BLAS)),
         ("CUBLASLT_MATMUL_DESC_D_SCALE_POINTER", ("HIPBLASLT_MATMUL_DESC_D_SCALE_POINTER", CONV_MATH_FUNC, API_BLAS)),
+        ("CUBLASLT_MATMUL_DESC_A_SCALE_MODE", ("HIPBLASLT_MATMUL_DESC_A_SCALE_MODE", CONV_MATH_FUNC, API_BLAS)),
+        ("CUBLASLT_MATMUL_DESC_B_SCALE_MODE", ("HIPBLASLT_MATMUL_DESC_B_SCALE_MODE", CONV_MATH_FUNC, API_BLAS)),
+        ("CUBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F", ("HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F", CONV_MATH_FUNC, API_BLAS)),
         ("CUBLASLT_MATMUL_DESC_AMAX_D_POINTER", ("HIPBLASLT_MATMUL_DESC_AMAX_D_POINTER", CONV_MATH_FUNC, API_BLAS)),
         ("CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE", ("HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE", CONV_MATH_FUNC, API_BLAS)),
         ("cublasLtMatrixLayout_t", ("hipblasLtMatrixLayout_t", CONV_MATH_FUNC, API_BLAS)),
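
For context, the user-visible path these branches serve is row-wise scaled FP8 matmul, where scale_a has shape (M, 1) and scale_b has shape (1, N), both float32 and contiguous, as checked in the Blas.cpp hunk above. Below is a minimal sketch of that call, assuming an FP8-capable GPU and a PyTorch build where the private torch._scaled_mm API and the torch.float8_e4m3fn dtype are available; it is illustrative only, not part of this patch.

# Hedged sketch: row-wise scaled FP8 GEMM, the path that scaled_gemm() maps onto the
# (hipified) CUBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F scale mode when use_rowwise is set.
# torch._scaled_mm is a private API; exact FP8 dtype support depends on the GPU/ROCm stack.
import torch

M, K, N = 128, 256, 64
device = "cuda"  # ROCm builds also expose the "cuda" device

a = torch.randn(M, K, device=device).to(torch.float8_e4m3fn)
# The second operand must be column-major for cuBLASLt/hipBLASLt,
# so build it as a transposed row-major tensor.
b = torch.randn(N, K, device=device).to(torch.float8_e4m3fn).t()

# One float32 scale per row of a (M, 1) and per column of b (1, N);
# get_scaling_type() additionally requires both to be contiguous.
scale_a = torch.rand(M, 1, device=device, dtype=torch.float32)
scale_b = torch.rand(1, N, device=device, dtype=torch.float32)

out = torch._scaled_mm(a, b, scale_a=scale_a, scale_b=scale_b, out_dtype=torch.bfloat16)
print(out.shape)  # torch.Size([128, 64])

On ROCm this reaches scaled_gemm() with use_rowwise=true; with a hipBLASLt new enough to define HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F, the call now goes through the outer-vec scale mode instead of the older *_SCALE_POINTER_VEC_EXT attributes.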