diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp
index c1a842653f4c4..88418551601d3 100644
--- a/aten/src/ATen/cuda/CUDABlas.cpp
+++ b/aten/src/ATen/cuda/CUDABlas.cpp
@@ -13,12 +13,13 @@
 // cublasLT was introduced in CUDA 10.1 but we enable only for 11.1 that also
 // added bf16 support
-#if !defined(USE_ROCM) && !defined(_MSC_VER)
+#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
 #include <cublasLt.h>
 #endif

 #ifdef USE_ROCM
 // until hipblas has an API to accept flags, we must use rocblas here
+#include
 #include <rocblas/rocblas.h>
 #define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR)
 #define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242)
@@ -64,6 +65,7 @@ static hipblasStatus_t rocBLASStatusToHIPStatus(rocblas_status error)
 // until we use hiblas v2
 // hipify correctly maps things like CUDA_R_16F to HIP_R_16F,
 // however hipblas v1 is still using its custom type
+#ifndef HIPBLAS_V2
 #define HIP_R_16F  HIPBLAS_R_16F
 #define HIP_R_32F  HIPBLAS_R_32F
 #define HIP_R_64F  HIPBLAS_R_64F
@@ -81,6 +83,7 @@ static hipblasStatus_t rocBLASStatusToHIPStatus(rocblas_status error)
 #define HIP_R_16BF HIPBLAS_R_16B
 #define HIP_C_16BF HIPBLAS_C_16B
 #endif
+#endif

 #define CUDABLAS_POSINT_CHECK(FD, X) \
   TORCH_CHECK(                       \
@@ -167,6 +170,7 @@ static void _cublasAdjustLdLevel3(
   }
 }

+#ifndef USE_ROCM
 uint32_t _getAlignment(uintptr_t address) {
   // alignment are in bytes
   uint32_t alignment = 256;
@@ -176,18 +180,25 @@ uint32_t _getAlignment(uintptr_t address) {
     }
   }
 }
+#endif

 static size_t _parseChosenWorkspaceSize() {
   const char * val = getenv("CUBLASLT_WORKSPACE_SIZE");
+#ifdef USE_ROCM
+  if (!val) {
+    // accept either env var
+    val = getenv("HIPBLASLT_WORKSPACE_SIZE");
+  }
+#endif
   size_t workspace_size = 1024; /* default size in KiB according to #73328 */
   if (val) {
     try {
       workspace_size = std::stoi(val);
     } catch(std::invalid_argument const& e) {
-      TORCH_WARN("invalid CUBLAS_LT_WORKSPACE_SIZE,",
+      TORCH_WARN("invalid CUBLASLT_WORKSPACE_SIZE,",
                  " using default workspace size of ", workspace_size, " bytes.");
     } catch(std::out_of_range const& e) {
-      TORCH_WARN("CUBLAS_LT_WORKSPACE_SIZE out of range,",
+      TORCH_WARN("CUBLASLT_WORKSPACE_SIZE out of range,",
                  " using default workspace size of ", workspace_size, " bytes.");
     }
   }
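As a plain-C++ reference, the environment-variable fallback added to _parseChosenWorkspaceSize() above amounts to the sketch below. The helper name is illustrative and the two catch clauses are collapsed into one; the value is interpreted in KiB, matching the default documented above.

    // Standalone sketch of the CUBLASLT_WORKSPACE_SIZE / HIPBLASLT_WORKSPACE_SIZE
    // fallback (illustrative helper, not part of the patch).
    #include <cstdlib>
    #include <stdexcept>
    #include <string>

    static size_t parse_workspace_size_kib() {
      const char* val = std::getenv("CUBLASLT_WORKSPACE_SIZE");
    #ifdef USE_ROCM
      if (!val) {
        // ROCm builds accept the HIP spelling as well
        val = std::getenv("HIPBLASLT_WORKSPACE_SIZE");
      }
    #endif
      size_t workspace_size = 1024;  // default, in KiB (see #73328)
      if (val) {
        try {
          workspace_size = std::stoi(val);
        } catch (const std::exception&) {
          // keep the default on malformed or out-of-range input
        }
      }
      return workspace_size;
    }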
@@ -346,7 +357,13 @@ void bgemm(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) {
       (void*)&falpha, a, CUDA_R_16BF, (int)lda, stridea,
       b, CUDA_R_16BF, (int)ldb, strideb,
       (void*)&fbeta, c, CUDA_R_16BF, (int)ldc, stridec,
-      (int)num_batches, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+      (int)num_batches,
+#if defined(USE_ROCM) && ROCM_VERSION >= 60000
+      CUBLAS_COMPUTE_32F,
+#else
+      CUDA_R_32F,
+#endif
+      CUBLAS_GEMM_DEFAULT_TENSOR_OP));
 }

 template <>
@@ -536,12 +553,66 @@ void gemm(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
       c,
       CUDA_R_16BF,
       ldc,
+#if defined(USE_ROCM) && ROCM_VERSION >= 60000
+      CUBLAS_COMPUTE_32F,
+#else
       CUDA_R_32F,
+#endif
       CUBLAS_GEMM_DEFAULT_TENSOR_OP));
   TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
 }

-#if !defined(USE_ROCM) && !defined(_MSC_VER)
+#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
+
+#if defined(USE_ROCM) && ROCM_VERSION >= 50700 && ROCM_VERSION < 60000
+// only for rocm 5.7 where we first supported hipblaslt, it was difficult
+// to hipify correctly without this change.
+#define hipDataType hipblasDatatype_t
+#endif
+
+// hipblaslt custom types were a temporary work-around
+#if defined(USE_ROCM) && ROCM_VERSION >= 60000 && HIPBLASLT_CUSTOM_DATA_TYPE
+hipblasltDatatype_t hipToLt(hipDataType type) {
+  switch (type) {
+    case HIP_R_32F: return HIPBLASLT_R_32F;
+    case HIP_R_64F: return HIPBLASLT_R_64F;
+    case HIP_R_16F: return HIPBLASLT_R_16F;
+    case HIP_R_8I: return HIPBLASLT_R_8I;
+    case HIP_C_32F: return HIPBLASLT_C_32F;
+    case HIP_C_64F: return HIPBLASLT_C_64F;
+    case HIP_C_16F: return HIPBLASLT_C_16F;
+    case HIP_C_8I: return HIPBLASLT_C_8I;
+    case HIP_R_8U: return HIPBLASLT_R_8U;
+    case HIP_C_8U: return HIPBLASLT_C_8U;
+    case HIP_R_32I: return HIPBLASLT_R_32I;
+    case HIP_C_32I: return HIPBLASLT_C_32I;
+    case HIP_R_32U: return HIPBLASLT_R_32U;
+    case HIP_C_32U: return HIPBLASLT_C_32U;
+    case HIP_R_16BF: return HIPBLASLT_R_16B;
+    case HIP_C_16BF: return HIPBLASLT_C_16B;
+    default: TORCH_CHECK(false);
+  }
+}
+#define HIPTOLT(type) hipToLt(type)
+#else
+#define HIPTOLT(type) type
+#endif
+
+#if defined(USE_ROCM) && ROCM_VERSION >= 60000 && HIPBLASLT_CUSTOM_COMPUTE_TYPE
+hipblasLtComputeType_t hipblasToLt(hipblasComputeType_t type) {
+  switch (type) {
+    case HIPBLAS_COMPUTE_32F: return HIPBLASLT_COMPUTE_F32;
+    case HIPBLAS_COMPUTE_32F_FAST_16F: return HIPBLASLT_COMPUTE_F32_FAST_F16;
+    case HIPBLAS_COMPUTE_32F_FAST_TF32: return HIPBLASLT_COMPUTE_F32_FAST_XF32;
+    case HIPBLAS_COMPUTE_64F: return HIPBLASLT_COMPUTE_F64;
+    case HIPBLAS_COMPUTE_32I: return HIPBLASLT_COMPUTE_I32;
+    default: TORCH_CHECK(false);
+  }
+}
+#define HIPCOMPTOLT(type) hipblasToLt(type)
+#else
+#define HIPCOMPTOLT(type) type
+#endif

 namespace {
 // Following the pattern of CuSparseDescriptor
@@ -580,7 +651,7 @@ class CuBlasLtMatmulDescriptor : public CuBlasLtDescriptor<
       cudaDataType_t scale_type) {
     cublasLtMatmulDesc_t raw_descriptor = nullptr;
     TORCH_CUDABLAS_CHECK(
-        cublasLtMatmulDescCreate(&raw_descriptor, compute_type, scale_type));
+        cublasLtMatmulDescCreate(&raw_descriptor, HIPCOMPTOLT(compute_type), HIPTOLT(scale_type)));
     descriptor_.reset(raw_descriptor);
   }
   template <typename T>
@@ -601,7 +672,7 @@ class CuBlasLtMatrixLayout : public CuBlasLtDescriptor<
       bool t = false) {
     cublasLtMatrixLayout_t raw_descriptor = nullptr;
     TORCH_CUDABLAS_CHECK(
-        cublasLtMatrixLayoutCreate(&raw_descriptor, type, t ? cols : rows, t ? rows : cols, ld));
+        cublasLtMatrixLayoutCreate(&raw_descriptor, HIPTOLT(type), t ? cols : rows, t ? rows : cols, ld));
     descriptor_.reset(raw_descriptor);
   }
 };
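The HIPTOLT()/HIPCOMPTOLT() wrappers above exist so that cublasLt call sites stay textually identical across CUDA and ROCm builds. A minimal sketch, assuming the macros from the patch (identity fallbacks are added so the snippet stands alone, and the function name is hypothetical):

    // On CUDA builds, and on ROCm builds where hipblaslt already consumes
    // hipDataType/hipblasComputeType_t, the wrappers are identity macros; only
    // ROCm 6.0 builds with the custom enums route through hipToLt()/hipblasToLt().
    #include <cublasLt.h>

    #ifndef HIPTOLT
    #define HIPTOLT(type) type
    #endif
    #ifndef HIPCOMPTOLT
    #define HIPCOMPTOLT(type) type
    #endif

    cublasStatus_t make_fp32_matmul_desc(cublasLtMatmulDesc_t* desc) {
      // The same source line works for cuBLASLt and, after hipify, for hipBLASLt.
      return cublasLtMatmulDescCreate(
          desc, HIPCOMPTOLT(CUBLAS_COMPUTE_32F), HIPTOLT(CUDA_R_32F));
    }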
@@ -645,13 +716,19 @@ void gemm_and_bias(
   cublasComputeType_t computeType = CUBLAS_COMPUTE_32F;
   cudaDataType_t scaleType = CUDA_R_32F;
   if constexpr (std::is_same_v<Dtype, double>) {
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 60000)
     abcType = CUDA_R_64F;
     computeType = CUBLAS_COMPUTE_64F;
     scaleType = CUDA_R_64F;
+#else
+    TORCH_CHECK(false, "gemm_and_bias is only supported for double type on ROCm 6.0 and above");
+#endif
   } else if constexpr (std::is_same_v<Dtype, float>) {
+#ifndef USE_ROCM
     if (at::globalContext().allowTF32CuBLAS()) {
       computeType = CUBLAS_COMPUTE_32F_FAST_TF32;
     }
+#endif
     abcType = CUDA_R_32F;
   } else if constexpr (std::is_same_v<Dtype, at::Half>) {
     abcType = CUDA_R_16F;
@@ -668,7 +745,7 @@ void gemm_and_bias(
   if (activation == GEMMAndBiasActivationEpilogue::RELU) {
     epilogue = CUBLASLT_EPILOGUE_RELU_BIAS;
   } else if (activation == GEMMAndBiasActivationEpilogue::GELU) {
-#if CUDA_VERSION >= 11040
+#if CUDA_VERSION >= 11040 || defined(USE_ROCM)
     epilogue = CUBLASLT_EPILOGUE_GELU_BIAS;
 #endif
   }
@@ -685,6 +762,7 @@ void gemm_and_bias(
   size_t workspaceSize = _getWorkspaceSize();
   preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspaceSize);

+#ifndef USE_ROCM
   uint32_t a_alignment = _getAlignment(reinterpret_cast<uintptr_t>(mat1_ptr));
   uint32_t b_alignment = _getAlignment(reinterpret_cast<uintptr_t>(mat2_ptr));
   uint32_t c_alignment = _getAlignment(reinterpret_cast<uintptr_t>(result_ptr));
@@ -693,6 +771,7 @@ void gemm_and_bias(
   preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES, a_alignment);
   preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_B_BYTES, b_alignment);
   preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES, c_alignment);
   preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES, d_alignment);
+#endif

   auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
   auto workspace = allocator.allocate(workspaceSize);
@@ -952,6 +1031,7 @@ void int8_gemm(
     int64_t mat2_ld,
     int32_t* result_ptr,
     int64_t result_ld) {
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 60000)
   cublasComputeType_t computeType = CUBLAS_COMPUTE_32I;
   cudaDataType_t scaleType = CUDA_R_32I;

@@ -1022,11 +1102,14 @@ void int8_gemm(
       computeType,
       " scaleType ",
       scaleType);
+#else
+  TORCH_CHECK(false, "int8_gemm is only supported for ROCm 6.0 and above");
+#endif // !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 60000)
 }

-#endif // !defined(USE_ROCM) && !defined(_MSC_VER)
+#endif // (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)

 // ROCm 5.6 hipblas matches the const Dtype *A API, but prior hipblas does not.
-#if defined(USE_ROCM) && ROCM_VERSION <= 56000
+#if defined(USE_ROCM) && ROCM_VERSION <= 50600
 #define ROCM_CONST_BUG
 #else
 #define ROCM_CONST_BUG const
diff --git a/aten/src/ATen/cuda/CUDABlas.h b/aten/src/ATen/cuda/CUDABlas.h
index 52acf9abb0dee..ee3b41b4376a9 100644
--- a/aten/src/ATen/cuda/CUDABlas.h
+++ b/aten/src/ATen/cuda/CUDABlas.h
@@ -62,7 +62,7 @@ void gemm(CUDABLAS_GEMM_ARGTYPES(at::Half));
 template <>
 void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16));

-#if !defined(USE_ROCM) && !defined(_MSC_VER)
+#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
 enum GEMMAndBiasActivationEpilogue {
   None,
   RELU,
@@ -149,7 +149,7 @@ void bgemm(CUDABLAS_BGEMM_ARGTYPES(at::Half));
 template <>
 void bgemm<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16));

-#if defined(USE_ROCM) && ROCM_VERSION <= 55000
+#if defined(USE_ROCM) && ROCM_VERSION <= 50500
 // ROCm 5.6 hipblas matches the const Dtype *A API, but prior hipblas does not.
 #define CUDABLAS_TRSM_ARGTYPES(Dtype)                                      \
   hipblasHandle_t handle, hipblasSideMode_t side, hipblasFillMode_t uplo,  \
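The corrected guards here (50600, 50500) and the new ones (50700, 60000) all rely on the ROCM_VERSION macro that Dependencies.cmake defines from ROCM_VERSION_DEV_INT, which encodes major*10000 + minor*100 + patch. A compile-time sanity check of that encoding, for reference only:

    // Why the guards must read 50600/50500 rather than 56000/55000, and why the
    // new hipBLASLt guards use 50700 and 60000.
    constexpr int rocm_version_int(int major, int minor, int patch) {
      return major * 10000 + minor * 100 + patch;
    }
    static_assert(rocm_version_int(5, 5, 0) == 50500, "ROCm 5.5.0");
    static_assert(rocm_version_int(5, 6, 0) == 50600, "ROCm 5.6.0");
    static_assert(rocm_version_int(5, 7, 0) == 50700, "ROCm 5.7.0");
    static_assert(rocm_version_int(6, 0, 0) == 60000, "ROCm 6.0.0");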
diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp
index e5163c339da99..35a247725a3ea 100644
--- a/aten/src/ATen/native/cuda/Blas.cpp
+++ b/aten/src/ATen/native/cuda/Blas.cpp
@@ -153,7 +153,7 @@ enum class Activation {
   GELU,
 };

-#if !defined(USE_ROCM) && !defined(_MSC_VER)
+#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
 cuda::blas::GEMMAndBiasActivationEpilogue activation_to_gemm_and_blas_arg(Activation a) {
   switch (a) {
     case Activation::None:
@@ -171,12 +171,40 @@ cuda::blas::GEMMAndBiasActivationEpilogue activation_to_gemm_and_blas_arg(Activa

 static bool getDisableAddmmCudaLt() {
     static const char* env_value = std::getenv("DISABLE_ADDMM_CUDA_LT");
+#ifdef USE_ROCM
+    // allow both CUDA and HIP env var names for ROCm builds
+    // also, the default for ROCm builds is to disable
+    if (env_value == nullptr) {
+        env_value = std::getenv("DISABLE_ADDMM_HIP_LT");
+    }
+    if (env_value != nullptr && strcmp(env_value, "0") == 0) {
+      return false;
+    }
+    return true;
+#else
     if (env_value != nullptr && strcmp(env_value, "1") == 0) {
       return true;
     }
     return false;
+#endif
 }

+#ifdef USE_ROCM
+static bool isSupportedHipLtROCmArch(int index) {
+    hipDeviceProp_t* prop = at::cuda::getDeviceProperties(index);
+    std::string device_arch = prop->gcnArchName;
+    static const std::vector<std::string> archs = {"gfx90a", "gfx940", "gfx941", "gfx942"};
+    for (std::string arch : archs) {
+        size_t substring = device_arch.find(arch);
+        if (substring != std::string::npos) {
+            return true;
+        }
+    }
+    TORCH_CHECK(false, "Attempting to use hipBLASLt on an unsupported architecture!");
+    return false;
+}
+#endif
+
 Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha, Activation activation=Activation::None) {
   // Make sure to keep addmm_cuda below in sync with this code; it
   // preflights a check to try to avoid actually needing to call
@@ -198,7 +226,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
   at::ScalarType scalar_type = self.scalar_type();
   c10::MaybeOwned<Tensor> self_;
   if (&result != &self) {
-#if defined(CUDA_VERSION) && CUDA_VERSION >= 11040 && !defined(_MSC_VER)
+#if (defined(CUDA_VERSION) && CUDA_VERSION >= 11040 && !defined(_MSC_VER)) || defined(USE_ROCM) && ROCM_VERSION >= 50700
     // Strangely, if mat2 has only 1 row or column, we get
     // CUBLAS_STATUS_INVALID_VALUE error from cublasLtMatmulAlgoGetHeuristic.
     // self.dim() == 1 && result.dim() == 2 && self.sizes()[0] == mat2_sizes[1]
@@ -211,10 +239,17 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
     useLtInterface = beta.toComplexDouble() == 1.0 && self.dim() == 1 &&
         result.dim() == 2 && self.sizes()[0] == mat2_sizes[1] &&
         self.is_contiguous() && result.is_contiguous() &&
+#ifdef USE_ROCM
+        isSupportedHipLtROCmArch(self.device().index()) &&
+        (scalar_type == at::ScalarType::Float ||
+         scalar_type == at::ScalarType::Half ||
+         scalar_type == at::ScalarType::BFloat16) &&
+#else
         (scalar_type == at::ScalarType::Double ||
          scalar_type == at::ScalarType::Float ||
          scalar_type == at::ScalarType::Half ||
         scalar_type == at::ScalarType::BFloat16) &&
+#endif
         mat2_sizes[0] > 1 && mat2_sizes[1] > 1 &&
         mat2_sizes[0] < 65535 * 32 && mat2_sizes[1] < 65535 * 32 &&
         mat1_sizes[0] < 65535 * 32 && mat1_sizes[1] < 65535 * 32 &&
@@ -234,6 +269,14 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
     }
     self__sizes = self_->sizes();
   } else {
+#if defined(USE_ROCM) && ROCM_VERSION >= 50700
+    useLtInterface = !disable_addmm_cuda_lt &&
+        result.dim() == 2 && result.is_contiguous() &&
+        isSupportedHipLtROCmArch(self.device().index()) &&
+        (scalar_type == at::ScalarType::Float ||
+         scalar_type == at::ScalarType::Half ||
+         scalar_type == at::ScalarType::BFloat16);
+#endif
     self_ = c10::MaybeOwned<Tensor>::borrowed(self);
     self__sizes = self_->sizes();
     TORCH_CHECK(result.dim() == 2, "tensors must be 2-D");
@@ -277,7 +320,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma

   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!args.result->is_conj());

-#if !defined(USE_ROCM) && !defined(_MSC_VER)
+#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
   if (useLtInterface) {
     AT_DISPATCH_FLOATING_TYPES_AND2(
         at::ScalarType::Half,
@@ -299,7 +342,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
             self.const_data_ptr<scalar_t>(),
             args.result->data_ptr<scalar_t>(),
             args.result_ld,
-#if defined(CUDA_VERSION) && CUDA_VERSION >= 11080
+#if (defined(CUDA_VERSION) && CUDA_VERSION >= 11080) || defined(USE_ROCM)
             activation_to_gemm_and_blas_arg(activation)
 #else
             // GELU is not supported (and does not compile!) prior
@@ -357,7 +400,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
   // gating activation_to_gemm_and_blas_arg above; here we are manually
   // performing a post-GELU because we weren't able to use the GELU
   // epilogue above.
-#if !defined(CUDA_VERSION) || CUDA_VERSION < 11080
+#if !(defined(CUDA_VERSION) && CUDA_VERSION >= 11080) && !defined(USE_ROCM)
   if (useLtInterface && activation == Activation::GELU) {
     at::gelu_(const_cast<Tensor&>(*args.result), "tanh");
   }
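The getDisableAddmmCudaLt() change above flips the default on ROCm: the Lt path stays opt-out on CUDA but becomes opt-in on ROCm, and either environment-variable spelling is accepted. A condensed, standalone sketch of that decision (illustrative helper name, not part of the patch):

    #include <cstdlib>
    #include <cstring>

    // Mirrors getDisableAddmmCudaLt(): returns true when the cuBLASLt/hipBLASLt
    // addmm path should be skipped.
    static bool addmm_lt_disabled() {
      const char* v = std::getenv("DISABLE_ADDMM_CUDA_LT");
    #ifdef USE_ROCM
      if (v == nullptr) {
        v = std::getenv("DISABLE_ADDMM_HIP_LT");   // HIP spelling also accepted
      }
      // ROCm: disabled by default; only "0" turns the Lt path on.
      return !(v != nullptr && std::strcmp(v, "0") == 0);
    #else
      // CUDA: enabled by default; only "1" turns the Lt path off.
      return v != nullptr && std::strcmp(v, "1") == 0;
    #endif
    }

On ROCm the Lt path is additionally gated on a supported GPU (gfx90a/gfx940/gfx941/gfx942) and on float/half/bfloat16 operands, per isSupportedHipLtROCmArch() and the scalar_type checks above.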
diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake
index f04e41709b51d..94d214cd2588d 100644
--- a/cmake/Dependencies.cmake
+++ b/cmake/Dependencies.cmake
@@ -1253,6 +1253,15 @@ if(USE_ROCM)
     list(APPEND HIP_CXX_FLAGS -DCAFFE2_USE_MIOPEN)
     list(APPEND HIP_CXX_FLAGS -DTHRUST_DEVICE_SYSTEM=THRUST_DEVICE_SYSTEM_HIP)
     list(APPEND HIP_CXX_FLAGS -std=c++17)
+    if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "6.0.0")
+      list(APPEND HIP_CXX_FLAGS -DHIPBLAS_V2)
+    endif()
+    if(HIPBLASLT_CUSTOM_DATA_TYPE)
+      list(APPEND HIP_CXX_FLAGS -DHIPBLASLT_CUSTOM_DATA_TYPE)
+    endif()
+    if(HIPBLASLT_CUSTOM_COMPUTE_TYPE)
+      list(APPEND HIP_CXX_FLAGS -DHIPBLASLT_CUSTOM_COMPUTE_TYPE)
+    endif()
     add_definitions(-DROCM_VERSION=${ROCM_VERSION_DEV_INT})
     add_definitions(-DTORCH_HIP_VERSION=${TORCH_HIP_VERSION})
     message("TORCH_HIP_VERSION=${TORCH_HIP_VERSION} is added as a compiler defines")
@@ -1278,6 +1287,9 @@ if(USE_ROCM)
     set(Caffe2_PUBLIC_HIP_DEPENDENCY_LIBS
       ${PYTORCH_HIP_LIBRARIES} ${PYTORCH_MIOPEN_LIBRARIES} ${hipcub_LIBRARIES} ${ROCM_HIPRTC_LIB} ${ROCM_ROCTX_LIB})
+    if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "5.7.0")
+      list(APPEND Caffe2_PUBLIC_HIP_DEPENDENCY_LIBS ${hipblaslt_LIBRARIES})
+    endif()

     list(APPEND Caffe2_PUBLIC_HIP_DEPENDENCY_LIBS
       roc::hipblas hip::hipfft hip::hiprand roc::hipsparse roc::hipsolver)
diff --git a/cmake/public/LoadHIP.cmake b/cmake/public/LoadHIP.cmake
index 6989f57f7090b..f7344cc310842 100644
--- a/cmake/public/LoadHIP.cmake
+++ b/cmake/public/LoadHIP.cmake
@@ -136,6 +136,7 @@ if(HIP_FOUND)
   set(hiprand_DIR ${ROCM_PATH}/lib/cmake/hiprand)
   set(rocblas_DIR ${ROCM_PATH}/lib/cmake/rocblas)
   set(hipblas_DIR ${ROCM_PATH}/lib/cmake/hipblas)
+  set(hipblaslt_DIR ${ROCM_PATH}/lib/cmake/hipblaslt)
   set(miopen_DIR ${ROCM_PATH}/lib/cmake/miopen)
   set(rocfft_DIR ${ROCM_PATH}/lib/cmake/rocfft)
   set(hipfft_DIR ${ROCM_PATH}/lib/cmake/hipfft)
@@ -154,6 +155,9 @@ if(HIP_FOUND)
   find_package_and_print_version(hiprand REQUIRED)
   find_package_and_print_version(rocblas REQUIRED)
   find_package_and_print_version(hipblas REQUIRED)
+  if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "5.7.0")
+    find_package_and_print_version(hipblaslt REQUIRED)
+  endif()
   find_package_and_print_version(miopen REQUIRED)
   if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "4.1.0")
     find_package_and_print_version(hipfft REQUIRED)
@@ -187,4 +191,57 @@ if(HIP_FOUND)
   find_library(ROCM_HIPRTC_LIB amdhip64 HINTS ${ROCM_PATH}/lib)
   # roctx is part of roctracer
   find_library(ROCM_ROCTX_LIB roctx64 HINTS ${ROCM_PATH}/lib)
+
+  if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "5.7.0")
+    # check whether hipblaslt is using its own datatype
+    set(file "${PROJECT_BINARY_DIR}/hipblaslt_test_data_type.cc")
+    file(WRITE ${file} ""
+      "#include <hipblaslt/hipblaslt.h>\n"
+      "int main() {\n"
+      "    hipblasltDatatype_t bar = HIPBLASLT_R_16F;\n"
+      "    return 0;\n"
+      "}\n"
+      )
+
+    try_compile(hipblaslt_compile_result ${PROJECT_RANDOM_BINARY_DIR} ${file}
+      CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${ROCM_INCLUDE_DIRS}"
+      COMPILE_DEFINITIONS -D__HIP_PLATFORM_AMD__ -D__HIP_PLATFORM_HCC__
+      OUTPUT_VARIABLE hipblaslt_compile_output)
+
+    if(hipblaslt_compile_result)
+      set(HIPBLASLT_CUSTOM_DATA_TYPE ON)
+      #message("hipblaslt is using custom data type: ${hipblaslt_compile_output}")
+      message("hipblaslt is using custom data type")
+    else()
+      set(HIPBLASLT_CUSTOM_DATA_TYPE OFF)
+      #message("hipblaslt is NOT using custom data type: ${hipblaslt_compile_output}")
+      message("hipblaslt is NOT using custom data type")
+    endif()
+
+    # check whether hipblaslt is using its own compute type
+    set(file "${PROJECT_BINARY_DIR}/hipblaslt_test_compute_type.cc")
+    file(WRITE ${file} ""
+      "#include <hipblaslt/hipblaslt.h>\n"
+      "int main() {\n"
+      "    hipblasLtComputeType_t baz = HIPBLASLT_COMPUTE_F32;\n"
+      "    return 0;\n"
+      "}\n"
+      )
+
+    try_compile(hipblaslt_compile_result ${PROJECT_RANDOM_BINARY_DIR} ${file}
+      CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${ROCM_INCLUDE_DIRS}"
+      COMPILE_DEFINITIONS -D__HIP_PLATFORM_AMD__ -D__HIP_PLATFORM_HCC__
+      OUTPUT_VARIABLE hipblaslt_compile_output)
+
+    if(hipblaslt_compile_result)
+      set(HIPBLASLT_CUSTOM_COMPUTE_TYPE ON)
+      #message("hipblaslt is using custom compute type: ${hipblaslt_compile_output}")
+      message("hipblaslt is using custom compute type")
+    else()
+      set(HIPBLASLT_CUSTOM_COMPUTE_TYPE OFF)
+      #message("hipblaslt is NOT using custom compute type: ${hipblaslt_compile_output}")
+      message("hipblaslt is NOT using custom compute type")
+    endif()
+  endif()
 endif()
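Written out as a normal translation unit, the first try_compile probe above is essentially the file below (the hipblaslt header path is an assumption). If it compiles, hipblaslt still ships its own hipblasltDatatype_t, HIPBLASLT_CUSTOM_DATA_TYPE is switched on, and the hipToLt() shim in CUDABlas.cpp is compiled in; the second probe plays the same role for hipblasLtComputeType_t and HIPBLASLT_CUSTOM_COMPUTE_TYPE.

    // Equivalent of the generated hipblaslt_test_data_type.cc probe.
    // A successful compile means hipblaslt still uses its own enum family.
    #include <hipblaslt/hipblaslt.h>

    int main() {
      hipblasltDatatype_t bar = HIPBLASLT_R_16F;
      (void)bar;
      return 0;
    }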
diff --git a/torch/utils/cpp_extension.py b/torch/utils/cpp_extension.py
index d2b2e169025c7..b80e22aeda72a 100644
--- a/torch/utils/cpp_extension.py
+++ b/torch/utils/cpp_extension.py
@@ -237,6 +237,9 @@ def _join_rocm_home(*paths) -> str:
     '-DUSE_ROCM=1',
 ]

+if ROCM_VERSION is not None and ROCM_VERSION >= (6, 0):
+    COMMON_HIP_FLAGS.append('-DHIPBLAS_V2')
+
 COMMON_HIPCC_FLAGS = [
     '-DCUDA_HAS_FP16=1',
     '-D__HIP_NO_HALF_OPERATORS__=1',
diff --git a/torch/utils/hipify/cuda_to_hip_mappings.py b/torch/utils/hipify/cuda_to_hip_mappings.py
index d2084620373b5..e906d90721422 100644
--- a/torch/utils/hipify/cuda_to_hip_mappings.py
+++ b/torch/utils/hipify/cuda_to_hip_mappings.py
@@ -611,6 +611,7 @@
     ("vector_types.h", ("hip/hip_vector_types.h", CONV_INCLUDE, API_RUNTIME)),
     ("cublas.h", ("hipblas/hipblas.h", CONV_INCLUDE_CUDA_MAIN_H, API_BLAS)),
     ("cublas_v2.h", ("hipblas/hipblas.h", CONV_INCLUDE_CUDA_MAIN_H, API_BLAS)),
+    ("cublasLt.h", ("hipblaslt/hipblaslt.h", CONV_INCLUDE_CUDA_MAIN_H, API_BLAS)),
     ("curand.h", ("hiprand/hiprand.h", CONV_INCLUDE_CUDA_MAIN_H, API_RAND)),
     ("curand_kernel.h", ("hiprand/hiprand_kernel.h", CONV_INCLUDE, API_RAND)),
     ("curand_discrete.h", ("hiprand/hiprand_kernel.h", CONV_INCLUDE, API_RAND)),
@@ -3851,7 +3852,7 @@
             HIP_UNSUPPORTED,
         ),
     ),
-    ("cudaDataType_t", ("hipDataType_t", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)),
+    ("cudaDataType_t", ("hipDataType", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)),
     ("cudaDataType", ("hipDataType", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)),
     ("CUDA_R_16BF", ("HIP_R_16BF", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)),
     ("CUDA_C_16BF", ("HIP_C_16BF", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)),
@@ -7271,6 +7272,65 @@
         "cublasDrotmg_v2",
         ("hipblasDrotmg_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED),
     ),
+    (
+        "cublasComputeType_t",
+        ("hipblasComputeType_t" if rocm_version >= (6, 0, 0) else "hipblasLtComputeType_t",
+         CONV_MATH_FUNC, API_BLAS)
+    ),
+    (
+        "CUBLAS_COMPUTE_32I",
+        ("HIPBLAS_COMPUTE_32I" if rocm_version >= (6, 0, 0) else "HIPBLASLT_COMPUTE_I32", CONV_MATH_FUNC, API_BLAS)
+    ),
+    (
+        "CUBLAS_COMPUTE_32F",
+        ("HIPBLAS_COMPUTE_32F" if rocm_version >= (6, 0, 0) else "HIPBLASLT_COMPUTE_F32", CONV_MATH_FUNC, API_BLAS)
+    ),
+    (
+        "CUBLAS_COMPUTE_64F",
+        ("HIPBLAS_COMPUTE_64F" if rocm_version >= (6, 0, 0) else "HIPBLASLT_COMPUTE_F64", CONV_MATH_FUNC, API_BLAS)
+    ),
+    ("cublasLtEpilogue_t", ("hipblasLtEpilogue_t", CONV_MATH_FUNC, API_BLAS)),
+    ("CUBLASLT_EPILOGUE_DEFAULT", ("HIPBLASLT_EPILOGUE_DEFAULT", CONV_MATH_FUNC, API_BLAS)),
+    ("CUBLASLT_EPILOGUE_RELU", ("HIPBLASLT_EPILOGUE_RELU", CONV_MATH_FUNC, API_BLAS)),
+    ("CUBLASLT_EPILOGUE_BIAS", ("HIPBLASLT_EPILOGUE_BIAS", CONV_MATH_FUNC, API_BLAS)),
+    ("CUBLASLT_EPILOGUE_RELU_BIAS", ("HIPBLASLT_EPILOGUE_RELU_BIAS", CONV_MATH_FUNC, API_BLAS)),
+    ("CUBLASLT_EPILOGUE_GELU", ("HIPBLASLT_EPILOGUE_GELU", CONV_MATH_FUNC, API_BLAS)),
+    ("CUBLASLT_EPILOGUE_GELU_BIAS", ("HIPBLASLT_EPILOGUE_GELU_BIAS", CONV_MATH_FUNC, API_BLAS)),
+    ("cublasLtHandle_t", ("hipblasLtHandle_t", CONV_MATH_FUNC, API_BLAS)),
+    ("cublasLtMatmulDesc_t", ("hipblasLtMatmulDesc_t", CONV_MATH_FUNC, API_BLAS)),
+    ("cublasLtMatmulDescOpaque_t", ("hipblasLtMatmulDescOpaque_t", CONV_MATH_FUNC, API_BLAS)),
+    ("cublasLtMatmulDescAttributes_t", ("hipblasLtMatmulDescAttributes_t", CONV_MATH_FUNC, API_BLAS)),
+    ("CUBLASLT_MATMUL_DESC_TRANSA", ("HIPBLASLT_MATMUL_DESC_TRANSA", CONV_MATH_FUNC, API_BLAS)),
+    ("CUBLASLT_MATMUL_DESC_TRANSB", ("HIPBLASLT_MATMUL_DESC_TRANSB", CONV_MATH_FUNC, API_BLAS)),
+    ("CUBLASLT_MATMUL_DESC_EPILOGUE", ("HIPBLASLT_MATMUL_DESC_EPILOGUE", CONV_MATH_FUNC, API_BLAS)),
+    ("CUBLASLT_MATMUL_DESC_BIAS_POINTER", ("HIPBLASLT_MATMUL_DESC_BIAS_POINTER", CONV_MATH_FUNC, API_BLAS)),
+    ("CUBLASLT_MATMUL_DESC_A_SCALE_POINTER", ("HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER", CONV_MATH_FUNC, API_BLAS)),
+    ("CUBLASLT_MATMUL_DESC_B_SCALE_POINTER", ("HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER", CONV_MATH_FUNC, API_BLAS)),
+    ("CUBLASLT_MATMUL_DESC_D_SCALE_POINTER", ("HIPBLASLT_MATMUL_DESC_D_SCALE_POINTER", CONV_MATH_FUNC, API_BLAS)),
+    ("CUBLASLT_MATMUL_DESC_AMAX_D_POINTER", ("HIPBLASLT_MATMUL_DESC_AMAX_D_POINTER", CONV_MATH_FUNC, API_BLAS)),
+    ("CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE", ("HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE", CONV_MATH_FUNC, API_BLAS)),
+    ("cublasLtMatrixLayout_t", ("hipblasLtMatrixLayout_t", CONV_MATH_FUNC, API_BLAS)),
+    ("cublasLtMatrixLayoutOpaque_t", ("hipblasLtMatrixLayoutOpaque_t", CONV_MATH_FUNC, API_BLAS)),
+    ("cublasLtMatrixLayoutAttribute_t", ("hipblasLtMatrixLayoutAttribute_t", CONV_MATH_FUNC, API_BLAS)),
+    ("cublasLtMatmulPreference_t", ("hipblasLtMatmulPreference_t", CONV_MATH_FUNC, API_BLAS)),
+    ("cublasLtMatmulPreferenceOpaque_t", ("hipblasLtMatmulPreferenceOpaque_t", CONV_MATH_FUNC, API_BLAS)),
+    ("cublasLtMatmulPreferenceAttributes_t", ("hipblasLtMatmulPreferenceAttributes_t", CONV_MATH_FUNC, API_BLAS)),
+    ("CUBLASLT_MATMUL_PREF_SEARCH_MODE", ("HIPBLASLT_MATMUL_PREF_SEARCH_MODE", CONV_MATH_FUNC, API_BLAS)),
+    ("CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES", ("HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES", CONV_MATH_FUNC, API_BLAS)),
+    ("cublasLtMatmulAlgo_t", ("hipblasLtMatmulAlgo_t", CONV_MATH_FUNC, API_BLAS)),
+    ("cublasLtMatmulHeuristicResult_t", ("hipblasLtMatmulHeuristicResult_t", CONV_MATH_FUNC, API_BLAS)),
+    ("cublasLtMatrixLayoutCreate", ("hipblasLtMatrixLayoutCreate", CONV_MATH_FUNC, API_BLAS)),
+    ("cublasLtMatrixLayoutDestroy", ("hipblasLtMatrixLayoutDestroy", CONV_MATH_FUNC, API_BLAS)),
+    ("cublasLtCreate", ("hipblasLtCreate", CONV_MATH_FUNC, API_BLAS)),
+    ("cublasLtDestroy", ("hipblasLtDestroy", CONV_MATH_FUNC, API_BLAS)),
+    ("cublasLtMatmulDescCreate", ("hipblasLtMatmulDescCreate", CONV_MATH_FUNC, API_BLAS)),
+    ("cublasLtMatmulDescDestroy", ("hipblasLtMatmulDescDestroy", CONV_MATH_FUNC, API_BLAS)),
+    ("cublasLtMatmulDescSetAttribute", ("hipblasLtMatmulDescSetAttribute", CONV_MATH_FUNC, API_BLAS)),
+    ("cublasLtMatmulPreferenceCreate", ("hipblasLtMatmulPreferenceCreate", CONV_MATH_FUNC, API_BLAS)),
+    ("cublasLtMatmulPreferenceDestroy", ("hipblasLtMatmulPreferenceDestroy", CONV_MATH_FUNC, API_BLAS)),
("hipblasLtMatmulPreferenceDestroy", CONV_MATH_FUNC, API_BLAS)), + ("cublasLtMatmulPreferenceSetAttribute", ("hipblasLtMatmulPreferenceSetAttribute", CONV_MATH_FUNC, API_BLAS)), + ("cublasLtMatmulAlgoGetHeuristic", ("hipblasLtMatmulAlgoGetHeuristic", CONV_MATH_FUNC, API_BLAS)), + ("cublasLtMatmul", ("hipblasLtMatmul", CONV_MATH_FUNC, API_BLAS)), ( "CURAND_STATUS_SUCCESS", ("HIPRAND_STATUS_SUCCESS", CONV_NUMERIC_LITERAL, API_RAND), @@ -7677,8 +7737,14 @@ HIP_UNSUPPORTED, ), ), - ("cuComplex", ("hipblasComplex", CONV_TYPE, API_BLAS)), - ("cuDoubleComplex", ("hipblasDoubleComplex", CONV_TYPE, API_BLAS)), + ( + "cuComplex", + ("hipComplex" if rocm_version >= (6, 0, 0) else "hipblasComplex", CONV_TYPE, API_BLAS) + ), + ( + "cuDoubleComplex", + ("hipDoubleComplex" if rocm_version >= (6, 0, 0) else "hipblasDoubleComplex", CONV_TYPE, API_BLAS), + ), ("cufftResult_t", ("hipfftResult_t", CONV_TYPE, API_FFT)), ("cufftResult", ("hipfftResult", CONV_TYPE, API_FFT)), ("CUFFT_SUCCESS", ("HIPFFT_SUCCESS", CONV_NUMERIC_LITERAL, API_FFT)),