From b330df01e1a7be977bcfb1ed720522e1b933326b Mon Sep 17 00:00:00 2001 From: Pruthvi Madugundu Date: Thu, 28 Sep 2023 00:56:43 -0700 Subject: [PATCH 1/6] Revert "Workaround of SWDEV-407984 (#1254)" This reverts commit e3a6481d69eee1fd2fe35b8f8fdbb35f43e0e0fb. --- aten/src/ATen/cuda/CUDABlas.cpp | 24 ------------------------ 1 file changed, 24 deletions(-) diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp index 87ed9d0671f3d..d05eeef7faa21 100644 --- a/aten/src/ATen/cuda/CUDABlas.cpp +++ b/aten/src/ATen/cuda/CUDABlas.cpp @@ -1525,24 +1525,12 @@ void getrsBatched>(CUDABLAS_GETRS_ARGTYPES(c10::complex void geqrfBatched(HIPBLAS_GEQRF_BATCHED_ARGTYPES(float)) { -#if ROCM_VERSION >= 50700 - if (batchsize == 0) { - *info = 0; - return ; - } -#endif TORCH_HIPBLAS_CHECK(cublasSgeqrfBatched( handle, m, n, A_array, lda, tau_array, info, batchsize)); } template <> void geqrfBatched(HIPBLAS_GEQRF_BATCHED_ARGTYPES(double)) { -#if ROCM_VERSION >= 50700 - if (batchsize == 0) { - *info = 0; - return ; - } -#endif TORCH_HIPBLAS_CHECK(cublasDgeqrfBatched( handle, m, n, A_array, lda, tau_array, info, batchsize)); } @@ -1552,12 +1540,6 @@ void geqrfBatched(HIPBLAS_GEQRF_BATCHED_ARGTYPES(double)) { template <> void geqrfBatched>( HIPBLAS_GEQRF_BATCHED_ARGTYPES(c10::complex)) { -#if ROCM_VERSION >= 50700 - if (batchsize == 0) { - *info = 0; - return ; - } -#endif TORCH_HIPBLAS_CHECK(cublasCgeqrfBatched( handle, m, @@ -1572,12 +1554,6 @@ void geqrfBatched>( template <> void geqrfBatched>( HIPBLAS_GEQRF_BATCHED_ARGTYPES(c10::complex)) { -#if ROCM_VERSION >= 50700 - if (batchsize == 0) { - *info = 0; - return ; - } -#endif TORCH_HIPBLAS_CHECK(cublasZgeqrfBatched( handle, m, From c3210b1176cb444b70521143eab6ff2ae6796788 Mon Sep 17 00:00:00 2001 From: Pruthvi Madugundu Date: Thu, 28 Sep 2023 00:56:59 -0700 Subject: [PATCH 2/6] Revert "[ROCM] Fix TestLinalgCUDA.test_qr_cuda_complex64." This reverts commit 146e291f9dfa30818643223ec65f5207e2c41bb8. --- aten/src/ATen/cuda/CUDABlas.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp index d05eeef7faa21..c1f34ada9113e 100644 --- a/aten/src/ATen/cuda/CUDABlas.cpp +++ b/aten/src/ATen/cuda/CUDABlas.cpp @@ -1544,9 +1544,9 @@ void geqrfBatched>( handle, m, n, - reinterpret_cast(A_array), + reinterpret_cast(A_array), lda, - reinterpret_cast(tau_array), + reinterpret_cast(tau_array), info, batchsize)); } From df27c6fd1fc5c2552032d5ed83557fc2f384f196 Mon Sep 17 00:00:00 2001 From: Pruthvi Madugundu Date: Thu, 28 Sep 2023 00:58:47 -0700 Subject: [PATCH 3/6] Revert "Integrate new batched linalg drivers (#1163)" This reverts commit 5cf78070341d9818e9982e1d9d8448997681f034. 
--- aten/src/ATen/cuda/CUDABlas.cpp | 332 +----------------- aten/src/ATen/cuda/CUDABlas.h | 157 +-------- aten/src/ATen/cuda/Exceptions.h | 14 - .../native/cuda/linalg/BatchLinearAlgebra.cpp | 7 +- .../cuda/linalg/BatchLinearAlgebraLib.h | 4 - .../cuda/linalg/BatchLinearAlgebraLibBlas.cpp | 6 +- test/test_linalg.py | 18 +- torch/utils/hipify/cuda_to_hip_mappings.py | 2 +- 8 files changed, 27 insertions(+), 513 deletions(-) diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp index c1f34ada9113e..fc3fdd9675314 100644 --- a/aten/src/ATen/cuda/CUDABlas.cpp +++ b/aten/src/ATen/cuda/CUDABlas.cpp @@ -1386,77 +1386,6 @@ void vdot>(CUDABLAS_DOT_ARGTYPES(c10::complex)) { reinterpret_cast(result))); } - -#ifdef USE_ROCM - - -template <> -void getrsBatched(HIPBLAS_GETRS_ARGTYPES(float)) { - TORCH_HIPBLAS_CHECK(cublasSgetrsBatched( - handle, - trans, - n, - nrhs, - dA_array, - lda, - ipiv_array, - dB_array, - ldb, - info_array, - batchsize)); -} - -template <> -void getrsBatched(HIPBLAS_GETRS_ARGTYPES(double)) { - TORCH_HIPBLAS_CHECK(cublasDgetrsBatched( - handle, - trans, - n, - nrhs, - dA_array, - lda, - ipiv_array, - dB_array, - ldb, - info_array, - batchsize)); -} - - -template <> -void getrsBatched>(HIPBLAS_GETRS_ARGTYPES(c10::complex)) { - TORCH_HIPBLAS_CHECK(cublasCgetrsBatched( - handle, - trans, - n, - nrhs, - reinterpret_cast(dA_array), - lda, - ipiv_array, - reinterpret_cast(dB_array), - ldb, - info_array, - batchsize)); -} - -template <> -void getrsBatched>(HIPBLAS_GETRS_ARGTYPES(c10::complex)) { - TORCH_HIPBLAS_CHECK(cublasZgetrsBatched( - handle, - trans, - n, - nrhs, - reinterpret_cast(dA_array), - lda, - ipiv_array, - reinterpret_cast(dB_array), - ldb, - info_array, - batchsize)); -} - -#else - template <> void getrsBatched(CUDABLAS_GETRS_ARGTYPES(float)) { TORCH_CUDABLAS_CHECK(cublasSgetrsBatched( @@ -1520,52 +1449,7 @@ void getrsBatched>(CUDABLAS_GETRS_ARGTYPES(c10::complex -void geqrfBatched(HIPBLAS_GEQRF_BATCHED_ARGTYPES(float)) { - TORCH_HIPBLAS_CHECK(cublasSgeqrfBatched( - handle, m, n, A_array, lda, tau_array, info, batchsize)); -} - -template <> -void geqrfBatched(HIPBLAS_GEQRF_BATCHED_ARGTYPES(double)) { - TORCH_HIPBLAS_CHECK(cublasDgeqrfBatched( - handle, m, n, A_array, lda, tau_array, info, batchsize)); -} - - - -template <> -void geqrfBatched>( - HIPBLAS_GEQRF_BATCHED_ARGTYPES(c10::complex)) { - TORCH_HIPBLAS_CHECK(cublasCgeqrfBatched( - handle, - m, - n, - reinterpret_cast(A_array), - lda, - reinterpret_cast(tau_array), - info, - batchsize)); -} - -template <> -void geqrfBatched>( - HIPBLAS_GEQRF_BATCHED_ARGTYPES(c10::complex)) { - TORCH_HIPBLAS_CHECK(cublasZgeqrfBatched( - handle, - m, - n, - reinterpret_cast(A_array), - lda, - reinterpret_cast(tau_array), - info, - batchsize)); -} -#else template <> void geqrfBatched(CUDABLAS_GEQRF_BATCHED_ARGTYPES(float)) { TORCH_CUDABLAS_CHECK(cublasSgeqrfBatched( @@ -1605,59 +1489,10 @@ void geqrfBatched>( info, batchsize)); } -#endif - -#ifdef USE_ROCM - -template <> -void getrfBatched( - HIPBLAS_GETRF_BATCHED_ARGTYPES(double)) { - auto handle = at::cuda::getCurrentCUDABlasHandle(); - TORCH_HIPBLAS_CHECK(cublasDgetrfBatched( - handle, n, dA_array, ldda, ipiv_array, info_array, batchsize)); -} - -template <> -void getrfBatched( - HIPBLAS_GETRF_BATCHED_ARGTYPES(float)) { - auto handle = at::cuda::getCurrentCUDABlasHandle(); - TORCH_HIPBLAS_CHECK(cublasSgetrfBatched( - handle, n, dA_array, ldda, ipiv_array, info_array, batchsize)); -} - -template <> -void getrfBatched>( - 
HIPBLAS_GETRF_BATCHED_ARGTYPES(c10::complex)) { - auto handle = at::cuda::getCurrentCUDABlasHandle(); - TORCH_HIPBLAS_CHECK(cublasZgetrfBatched( - handle, - n, - reinterpret_cast(dA_array), - ldda, - ipiv_array, - info_array, - batchsize)); -} -template <> -void getrfBatched>( - HIPBLAS_GETRF_BATCHED_ARGTYPES(c10::complex)) { - auto handle = at::cuda::getCurrentCUDABlasHandle(); - TORCH_HIPBLAS_CHECK(cublasCgetrfBatched( - handle, - n, - reinterpret_cast(dA_array), - ldda, - ipiv_array, - info_array, - batchsize)); -} - - -#else template <> void getrfBatched( - CUDABLAS_GETRF_BATCHED_ARGTYPES(double)) { + int n, double** dA_array, int ldda, int* ipiv_array, int* info_array, int batchsize) { auto handle = at::cuda::getCurrentCUDABlasHandle(); TORCH_CUDABLAS_CHECK(cublasDgetrfBatched( handle, n, dA_array, ldda, ipiv_array, info_array, batchsize)); @@ -1665,7 +1500,7 @@ void getrfBatched( template <> void getrfBatched( - CUDABLAS_GETRF_BATCHED_ARGTYPES(float)) { + int n, float** dA_array, int ldda, int* ipiv_array, int* info_array, int batchsize) { auto handle = at::cuda::getCurrentCUDABlasHandle(); TORCH_CUDABLAS_CHECK(cublasSgetrfBatched( handle, n, dA_array, ldda, ipiv_array, info_array, batchsize)); @@ -1673,7 +1508,12 @@ void getrfBatched( template <> void getrfBatched>( - CUDABLAS_GETRF_BATCHED_ARGTYPES(c10::complex)) { + int n, + c10::complex** dA_array, + int ldda, + int* ipiv_array, + int* info_array, + int batchsize) { auto handle = at::cuda::getCurrentCUDABlasHandle(); TORCH_CUDABLAS_CHECK(cublasZgetrfBatched( handle, @@ -1687,7 +1527,12 @@ void getrfBatched>( template <> void getrfBatched>( - CUDABLAS_GETRF_BATCHED_ARGTYPES(c10::complex)) { + int n, + c10::complex** dA_array, + int ldda, + int* ipiv_array, + int* info_array, + int batchsize) { auto handle = at::cuda::getCurrentCUDABlasHandle(); TORCH_CUDABLAS_CHECK(cublasCgetrfBatched( handle, @@ -1698,152 +1543,7 @@ void getrfBatched>( info_array, batchsize)); } -#endif - -#ifdef USE_ROCM -template <> -void getriBatched( - HIPBLAS_GETRI_BATCHED_ARGTYPES(double)) { - auto handle = at::cuda::getCurrentCUDABlasHandle(); - TORCH_HIPBLAS_CHECK(cublasDgetriBatched( - handle, n, dA_array, ldda, ipiv_array, dC_array, lddc, info_array, batchsize)); -} - -template <> -void getriBatched( - HIPBLAS_GETRI_BATCHED_ARGTYPES(float)) { - auto handle = at::cuda::getCurrentCUDABlasHandle(); - TORCH_HIPBLAS_CHECK(cublasSgetriBatched( - handle, n, dA_array, ldda, ipiv_array, dC_array, lddc, info_array, batchsize)); -} - -template <> -void getriBatched>( - HIPBLAS_GETRI_BATCHED_ARGTYPES(c10::complex)) { - auto handle = at::cuda::getCurrentCUDABlasHandle(); - TORCH_HIPBLAS_CHECK(cublasZgetriBatched( - handle, - n, - reinterpret_cast(dA_array), - ldda, - ipiv_array, - reinterpret_cast(dC_array), - lddc, - info_array, - batchsize)); -} - -template <> -void getriBatched>( - HIPBLAS_GETRI_BATCHED_ARGTYPES(c10::complex)) { - auto handle = at::cuda::getCurrentCUDABlasHandle(); - TORCH_HIPBLAS_CHECK(cublasCgetriBatched( - handle, - n, - reinterpret_cast(dA_array), - ldda, - ipiv_array, - reinterpret_cast(dC_array), - lddc, - info_array, - batchsize)); -} - -#else -template <> -void getriBatched( - CUDABLAS_GETRI_BATCHED_ARGTYPES(double)) { - auto handle = at::cuda::getCurrentCUDABlasHandle(); - TORCH_CUDABLAS_CHECK(cublasDgetriBatched( - handle, n, dA_array, ldda, ipiv_array, dC_array, lddc, info_array, batchsize)); -} - -template <> -void getriBatched( - CUDABLAS_GETRI_BATCHED_ARGTYPES(float)) { - auto handle = at::cuda::getCurrentCUDABlasHandle(); - 
TORCH_CUDABLAS_CHECK(cublasSgetriBatched( - handle, n, dA_array, ldda, ipiv_array, dC_array, lddc, info_array, batchsize)); -} - -template <> -void getriBatched>( - CUDABLAS_GETRI_BATCHED_ARGTYPES(c10::complex)) { - auto handle = at::cuda::getCurrentCUDABlasHandle(); - TORCH_CUDABLAS_CHECK(cublasZgetriBatched( - handle, - n, - reinterpret_cast(dA_array), - ldda, - ipiv_array, - reinterpret_cast(dC_array), - lddc, - info_array, - batchsize)); -} - -template <> -void getriBatched>( - CUDABLAS_GETRI_BATCHED_ARGTYPES(c10::complex)) { - auto handle = at::cuda::getCurrentCUDABlasHandle(); - TORCH_CUDABLAS_CHECK(cublasCgetriBatched( - handle, - n, - reinterpret_cast(dA_array), - ldda, - ipiv_array, - reinterpret_cast(dC_array), - lddc, - info_array, - batchsize)); -} -#endif - -#if defined(USE_ROCM) && (ROCM_VERSION >= 50400) - -template <> -void gelsBatched(HIPBLAS_GELS_BATCHED_ARGTYPES(double)) { - TORCH_HIPBLAS_CHECK(hipblasDgelsBatched( - handle, trans, m, n, nrhs, dA_array, ldda, dC_array, lddc, info, devInfoArray, batchSize)); -} -template <> -void gelsBatched(HIPBLAS_GELS_BATCHED_ARGTYPES(float)) { - TORCH_HIPBLAS_CHECK(hipblasSgelsBatched( - handle, trans, m, n, nrhs, dA_array, ldda, dC_array, lddc, info, devInfoArray, batchSize)); -} - -template <> -void gelsBatched>(HIPBLAS_GELS_BATCHED_ARGTYPES(c10::complex)) { - TORCH_HIPBLAS_CHECK(hipblasZgelsBatched( - handle, trans, - m, n, nrhs, - reinterpret_cast(dA_array), - ldda, - reinterpret_cast(dC_array), - lddc, - info, - devInfoArray, - batchSize)); -} - -template <> -void gelsBatched>(HIPBLAS_GELS_BATCHED_ARGTYPES(c10::complex)) { - TORCH_HIPBLAS_CHECK(hipblasCgelsBatched( - handle, trans, - m, n, nrhs, - reinterpret_cast(dA_array), - ldda, - reinterpret_cast(dC_array), - lddc, - info, - devInfoArray, - batchSize)); -} - -#else - -#ifdef CUDART_VERSION template <> void gelsBatched(CUDABLAS_GELS_BATCHED_ARGTYPES(double)) { @@ -1885,8 +1585,4 @@ void gelsBatched>(CUDABLAS_GELS_BATCHED_ARGTYPES(c10::comple batchSize)); } -#endif //CUDART_VERSION -#endif //USE_ROCM - } // namespace at::cuda::blas - diff --git a/aten/src/ATen/cuda/CUDABlas.h b/aten/src/ATen/cuda/CUDABlas.h index c74cf71273a92..54052c0883dbc 100644 --- a/aten/src/ATen/cuda/CUDABlas.h +++ b/aten/src/ATen/cuda/CUDABlas.h @@ -16,12 +16,6 @@ #include #include -#ifdef USE_ROCM -#include -#include -#endif - - namespace at::cuda::blas { // RAII guard that sets the CuBLAS pointer mode and restores it to @@ -262,31 +256,6 @@ void vdot>(CUDABLAS_DOT_ARGTYPES(c10::complex)); template <> void vdot>(CUDABLAS_DOT_ARGTYPES(c10::complex)); -#ifdef USE_ROCM - - -#define HIPBLAS_GETRS_ARGTYPES(Dtype) \ - hipblasHandle_t handle, hipblasOperation_t trans, \ - int n, int nrhs, Dtype** dA_array, int lda, int* ipiv_array, \ - Dtype** dB_array, int ldb, int* info_array, int batchsize - -template -void getrsBatched(HIPBLAS_GETRS_ARGTYPES(Dtype)) { - TORCH_INTERNAL_ASSERT(false, "at::cuda::blas::getrsBatched: not implemented for ", - typeid(Dtype).name()); -} -template<> -TORCH_CUDA_CU_API void getrsBatched(HIPBLAS_GETRS_ARGTYPES(float)); -template<> -TORCH_CUDA_CU_API void getrsBatched(HIPBLAS_GETRS_ARGTYPES(double)); -template<> -TORCH_CUDA_CU_API void getrsBatched>(HIPBLAS_GETRS_ARGTYPES(c10::complex)); -template<> -TORCH_CUDA_CU_API void getrsBatched>(HIPBLAS_GETRS_ARGTYPES(c10::complex)); - - -#else - #define CUDABLAS_GETRS_ARGTYPES(Dtype) \ cublasHandle_t handle, cublasOperation_t trans, \ int n, int nrhs, Dtype** dA_array, int lda, int* ipiv_array, \ @@ -306,31 +275,6 @@ TORCH_CUDA_CU_API void 
getrsBatched>(CUDABLAS_GETRS_ARGTYPES template<> TORCH_CUDA_CU_API void getrsBatched>(CUDABLAS_GETRS_ARGTYPES(c10::complex)); -#endif - -#ifdef USE_ROCM -#define HIPBLAS_GEQRF_BATCHED_ARGTYPES(Dtype) \ - hipblasHandle_t handle, int m, int n, Dtype **A_array, int lda, \ - Dtype **tau_array, int *info, int batchsize - -template -void geqrfBatched(HIPBLAS_GEQRF_BATCHED_ARGTYPES(Dtype)) { - TORCH_INTERNAL_ASSERT( - false, - "at::cuda::blas::geqrfBatched: not implemented for ", - typeid(Dtype).name()); -} -template <> -TORCH_CUDA_CU_API void geqrfBatched(HIPBLAS_GEQRF_BATCHED_ARGTYPES(float)); -template <> -TORCH_CUDA_CU_API void geqrfBatched(HIPBLAS_GEQRF_BATCHED_ARGTYPES(double)); -template <> -TORCH_CUDA_CU_API void geqrfBatched>( - HIPBLAS_GEQRF_BATCHED_ARGTYPES(c10::complex)); -template <> -TORCH_CUDA_CU_API void geqrfBatched>( - HIPBLAS_GEQRF_BATCHED_ARGTYPES(c10::complex)); -#else #define CUDABLAS_GEQRF_BATCHED_ARGTYPES(Dtype) \ cublasHandle_t handle, int m, int n, Dtype **A_array, int lda, \ Dtype **tau_array, int *info, int batchsize @@ -352,107 +296,22 @@ TORCH_CUDA_CU_API void geqrfBatched>( template <> TORCH_CUDA_CU_API void geqrfBatched>( CUDABLAS_GEQRF_BATCHED_ARGTYPES(c10::complex)); -#endif -#ifdef USE_ROCM -#define HIPBLAS_GETRF_BATCHED_ARGTYPES(Dtype) \ - int n, Dtype** dA_array, int ldda, int* ipiv_array, int* info_array, int batchsize -template -void getrfBatched(HIPBLAS_GETRF_BATCHED_ARGTYPES(Dtype)) { - TORCH_CHECK(false, "at::cuda::blas::getrfBatched: not implemented for ", typeid(Dtype).name()); -} -template<> -TORCH_CUDA_CU_API void getrfBatched(HIPBLAS_GETRF_BATCHED_ARGTYPES(float)); -template<> -TORCH_CUDA_CU_API void getrfBatched(HIPBLAS_GETRF_BATCHED_ARGTYPES(double)); -template<> -TORCH_CUDA_CU_API void getrfBatched>(HIPBLAS_GETRF_BATCHED_ARGTYPES(c10::complex)); -template<> -TORCH_CUDA_CU_API void getrfBatched>(HIPBLAS_GETRF_BATCHED_ARGTYPES(c10::complex)); - -#else - -#define CUDABLAS_GETRF_BATCHED_ARGTYPES(Dtype) \ +#define CUDABLAS_GETRF_ARGTYPES(Dtype) \ int n, Dtype** dA_array, int ldda, int* ipiv_array, int* info_array, int batchsize template -void getrfBatched(CUDABLAS_GETRF_BATCHED_ARGTYPES(Dtype)) { +void getrfBatched(CUDABLAS_GETRF_ARGTYPES(Dtype)) { TORCH_CHECK(false, "at::cuda::blas::getrfBatched: not implemented for ", typeid(Dtype).name()); } template<> -TORCH_CUDA_CU_API void getrfBatched(CUDABLAS_GETRF_BATCHED_ARGTYPES(float)); +TORCH_CUDA_CU_API void getrfBatched(CUDABLAS_GETRF_ARGTYPES(float)); template<> -TORCH_CUDA_CU_API void getrfBatched(CUDABLAS_GETRF_BATCHED_ARGTYPES(double)); +TORCH_CUDA_CU_API void getrfBatched(CUDABLAS_GETRF_ARGTYPES(double)); template<> -TORCH_CUDA_CU_API void getrfBatched>(CUDABLAS_GETRF_BATCHED_ARGTYPES(c10::complex)); +TORCH_CUDA_CU_API void getrfBatched>(CUDABLAS_GETRF_ARGTYPES(c10::complex)); template<> -TORCH_CUDA_CU_API void getrfBatched>(CUDABLAS_GETRF_BATCHED_ARGTYPES(c10::complex)); -#endif - - -#ifdef USE_ROCM -#define HIPBLAS_GETRI_BATCHED_ARGTYPES(Dtype) \ - int n, Dtype** dA_array, int ldda, int* ipiv_array, Dtype** dC_array, int lddc, int* info_array, int batchsize - -template -void getriBatched(HIPBLAS_GETRI_BATCHED_ARGTYPES(Dtype)) { - TORCH_CHECK(false, "at::cuda::blas::getriBatched: not implemented for ", typeid(Dtype).name()); -} -template<> -TORCH_CUDA_CU_API void getriBatched(HIPBLAS_GETRI_BATCHED_ARGTYPES(float)); -template<> -TORCH_CUDA_CU_API void getriBatched(HIPBLAS_GETRI_BATCHED_ARGTYPES(double)); -template<> -TORCH_CUDA_CU_API void getriBatched>(HIPBLAS_GETRI_BATCHED_ARGTYPES(c10::complex)); 
-template<> -TORCH_CUDA_CU_API void getriBatched>(HIPBLAS_GETRI_BATCHED_ARGTYPES(c10::complex)); - - -#else - - -#define CUDABLAS_GETRI_BATCHED_ARGTYPES(Dtype) \ - int n, Dtype** dA_array, int ldda, int* ipiv_array, Dtype** dC_array, int lddc, int* info_array, int batchsize - -template -void getriBatched(CUDABLAS_GETRI_BATCHED_ARGTYPES(Dtype)) { - TORCH_CHECK(false, "at::cuda::blas::getriBatched: not implemented for ", typeid(Dtype).name()); -} -template<> -TORCH_CUDA_CU_API void getriBatched(CUDABLAS_GETRI_BATCHED_ARGTYPES(float)); -template<> -TORCH_CUDA_CU_API void getriBatched(CUDABLAS_GETRI_BATCHED_ARGTYPES(double)); -template<> -TORCH_CUDA_CU_API void getriBatched>(CUDABLAS_GETRI_BATCHED_ARGTYPES(c10::complex)); -template<> -TORCH_CUDA_CU_API void getriBatched>(CUDABLAS_GETRI_BATCHED_ARGTYPES(c10::complex)); - -#endif - - - -#if defined(USE_ROCM) && (ROCM_VERSION >= 50400) - -#define HIPBLAS_GELS_BATCHED_ARGTYPES(Dtype) \ - hipblasHandle_t handle, hipblasOperation_t trans, int m, int n, int nrhs, Dtype** dA_array, int ldda, Dtype** dC_array, int lddc, int* info, int *devInfoArray, int batchSize - -template -void gelsBatched(HIPBLAS_GELS_BATCHED_ARGTYPES(Dtype)) { - TORCH_INTERNAL_ASSERT(false, "at::cuda::blas::gelsBatched: not implemented for ", typeid(Dtype).name()); -} - -template<> -TORCH_CUDA_CU_API void gelsBatched(HIPBLAS_GELS_BATCHED_ARGTYPES(double)); -template<> -TORCH_CUDA_CU_API void gelsBatched(HIPBLAS_GELS_BATCHED_ARGTYPES(float)); -template<> -TORCH_CUDA_CU_API void gelsBatched>(HIPBLAS_GELS_BATCHED_ARGTYPES(c10::complex)); -template<> -TORCH_CUDA_CU_API void gelsBatched>(HIPBLAS_GELS_BATCHED_ARGTYPES(c10::complex)); - -#else - -#ifdef CUDART_VERSION +TORCH_CUDA_CU_API void getrfBatched>(CUDABLAS_GETRF_ARGTYPES(c10::complex)); #define CUDABLAS_GELS_BATCHED_ARGTYPES(Dtype) \ cublasHandle_t handle, cublasOperation_t trans, int m, int n, int nrhs, Dtype** dA_array, int ldda, Dtype** dC_array, int lddc, int* info, int *devInfoArray, int batchSize @@ -471,8 +330,4 @@ TORCH_CUDA_CU_API void gelsBatched>(CUDABLAS_GELS_BATCHED_A template<> TORCH_CUDA_CU_API void gelsBatched>(CUDABLAS_GELS_BATCHED_ARGTYPES(c10::complex)); -#endif //CUDART_VERSION -#endif //USE_ROCM - - } // namespace at::cuda::blas diff --git a/aten/src/ATen/cuda/Exceptions.h b/aten/src/ATen/cuda/Exceptions.h index b0b0c20f6a1eb..a15f0d7947ec2 100644 --- a/aten/src/ATen/cuda/Exceptions.h +++ b/aten/src/ATen/cuda/Exceptions.h @@ -12,9 +12,6 @@ #include #include -#ifdef USE_ROCM -#include -#endif namespace c10 { @@ -56,17 +53,6 @@ C10_EXPORT const char* _cublasGetErrorEnum(cublasStatus_t error); " when calling `" #EXPR "`"); \ } while (0) -#ifdef USE_ROCM -#define TORCH_HIPBLAS_CHECK(EXPR) \ - do { \ - hipblasStatus_t __err = EXPR; \ - TORCH_CHECK(__err == HIPBLAS_STATUS_SUCCESS, \ - "CUDA error: ", \ - " when calling `" #EXPR "`"); \ - } while (0) -#endif - - const char *cusparseGetErrorString(cusparseStatus_t status); #define TORCH_CUDASPARSE_CHECK(EXPR) \ diff --git a/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebra.cpp b/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebra.cpp index 3351caec14b2b..a08547dc21b6a 100644 --- a/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebra.cpp +++ b/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebra.cpp @@ -40,8 +40,6 @@ #include #include - - const bool use_magma_ = true; namespace { @@ -2675,6 +2673,7 @@ void linalg_lstsq_gels(const Tensor& A, const Tensor& B, const Tensor& /*infos*/ /*unitriangular=*/false); } else { // underdetermined case Tensor Ah = 
cloneBatchedColumnMajor(A.mH()); + // Step 1: compute QR factorization of conjugate transpose of A using geqrf geqrf_kernel(Ah, tau); @@ -2745,7 +2744,7 @@ void lstsq_kernel(const Tensor& a, Tensor& b, Tensor& /*rank*/, Tensor& /*singul "Please rebuild with cuSOLVER."); #endif } else { // m >= n -//#if !AT_ROCM_ENABLED() +#if !AT_ROCM_ENABLED() // On CUDA platform we use either cuBLAS or cuSOLVER here // the batched vs looped dispatch is implemented based on the following performance results // https://github.com/pytorch/pytorch/pull/54725#issuecomment-832234456 @@ -2754,13 +2753,11 @@ void lstsq_kernel(const Tensor& a, Tensor& b, Tensor& /*rank*/, Tensor& /*singul } else { gels_looped(a, b, infos); } -/* #else // On ROCm platform we can only use MAGMA here // If MAGMA is not available, an error will be thrown gels_magma(a, b, infos); #endif // !AT_ROCM_ENABLED() -*/ } } diff --git a/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLib.h b/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLib.h index 036a98c54ad70..2cf2004e8e724 100644 --- a/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLib.h +++ b/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLib.h @@ -8,10 +8,6 @@ #include #include -#ifdef USE_ROCM -#include -#endif - #if (defined(CUDART_VERSION) && defined(CUSOLVER_VERSION)) || (defined(USE_ROCM) && ROCM_VERSION >= 50300) #define USE_LINALG_SOLVER #endif diff --git a/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLibBlas.cpp b/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLibBlas.cpp index b07eb21b4f332..9f7c0c02663b9 100644 --- a/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLibBlas.cpp +++ b/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLibBlas.cpp @@ -131,14 +131,10 @@ template static void apply_lu_solve_batched_cublas(const Tensor& LU, const Tensor& pivots, const Tensor& B, TransposeType transpose) { TORCH_INTERNAL_ASSERT(batchCount(LU) == batchCount(B), "batch_size of LU and B must be the same"); TORCH_INTERNAL_ASSERT(batchCount(LU) == batchCount(pivots.unsqueeze(-1)), "batch_size of LU and pivots must be the same"); -#ifdef USE_ROCM - const auto trans = (hipblasOperation_t)to_cublas(transpose); -#else const auto trans = to_cublas(transpose); -#endif auto pivots_data = pivots.data_ptr(); - auto batch_size = cuda_int_cast(batchCount(LU), "batch_size"); + auto batch_size = cuda_int_cast(batchCount(LU), "batch_size");; auto m = cuda_int_cast(LU.size(-2), "m"); auto nrhs = cuda_int_cast(B.size(-1), "nrhs"); auto lda = cuda_int_cast(std::max(1, m), "lda"); diff --git a/test/test_linalg.py b/test/test_linalg.py index a0307016e5b76..3bbc537b9c595 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -18,7 +18,7 @@ TEST_WITH_ROCM, IS_FBCODE, IS_REMOTE_GPU, iter_indices, make_fullrank_matrices_with_distinct_singular_values, freeze_rng_state, IS_ARM64, IS_SANDCASTLE, TEST_OPT_EINSUM, parametrize, skipIfTorchDynamo, - setLinalgBackendsToDefaultFinally, skipIfRocm) + setLinalgBackendsToDefaultFinally) from torch.testing._internal.common_device_type import \ (instantiate_device_type_tests, dtypes, has_cusolver, has_hipsolver, onlyCPU, skipCUDAIf, skipCUDAIfNoMagma, skipCPUIfNoLapack, precisionOverride, @@ -3637,13 +3637,10 @@ def test_linalg_qr_autograd_errors(self, device, dtype): "The QR decomposition is not differentiable when mode='complete' and nrows > ncols"): b.backward() + @skipCUDAIfNoCusolver @skipCPUIfNoLapack @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble) - @dtypesIfCUDA(*floating_types_and( - *[torch.cfloat] if not 
TEST_WITH_ROCM else [], - *[torch.cdouble] if not TEST_WITH_ROCM else [])) def test_qr_batched(self, device, dtype): - torch.backends.cuda.preferred_linalg_library("cusolver") """ test torch.linalg.qr vs numpy.linalg.qr. We need some special logic because numpy does not support batched qr @@ -4501,9 +4498,7 @@ def renorm(matrix, value, dim, max_norm): self.assertEqual(m3.norm(2, 0), m2.norm(2, 0)) @skipCPUIfNoLapack - @dtypesIfCUDA(*floating_types_and( - *[torch.cfloat] if not TEST_WITH_ROCM else [], - *[torch.cdouble] if not TEST_WITH_ROCM else [])) + @skipCUDAIfNoCusolver @dtypes(*floating_and_complex_types()) def test_ormqr(self, device, dtype): @@ -4760,9 +4755,6 @@ def test_renorm_ps(self, device): @skipCPUIfNoLapack @skipCUDAIfNoCusolver - @dtypesIfCUDA(*floating_types_and( - *[torch.cfloat] if not TEST_WITH_ROCM else [], - *[torch.cdouble] if not TEST_WITH_ROCM else [])) @dtypes(*floating_and_complex_types()) def test_householder_product(self, device, dtype): def generate_reflectors_and_tau(A): @@ -7155,7 +7147,6 @@ def sub_test(pivot): @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3, torch.float64: 1e-8, torch.complex128: 1e-8}) def test_lu_solve_batched(self, device, dtype): - torch.backends.cuda.preferred_linalg_library('cusolver') def sub_test(pivot): def lu_solve_batch_test_helper(A_dims, b_dims, pivot): b, A, LU_data, LU_pivots = self.lu_solve_test_helper(A_dims, b_dims, pivot, device, dtype) @@ -7331,9 +7322,6 @@ def test_nuclear_norm_out(self, device, dtype): @skipCUDAIfNoMagmaAndNoCusolver @skipCPUIfNoLapack - @dtypesIfCUDA(*floating_types_and( - *[torch.cfloat] if not TEST_WITH_ROCM else [], - *[torch.cdouble] if not TEST_WITH_ROCM else [])) @dtypes(*floating_and_complex_types()) def test_geqrf(self, device, dtype): diff --git a/torch/utils/hipify/cuda_to_hip_mappings.py b/torch/utils/hipify/cuda_to_hip_mappings.py index c09dd98dc5676..984f3148e9293 100644 --- a/torch/utils/hipify/cuda_to_hip_mappings.py +++ b/torch/utils/hipify/cuda_to_hip_mappings.py @@ -4,7 +4,7 @@ import subprocess from .constants import (API_BLAS, API_C10, API_CAFFE2, API_DRIVER, API_FFT, - API_PYTORCH, API_PYT_EXT, API_RAND, API_ROCTX, API_RTC, API_RUNTIME, + API_PYTORCH, API_RAND, API_ROCTX, API_RTC, API_RUNTIME, API_SPECIAL, API_ROCMSMI, CONV_CACHE, CONV_CONTEXT, CONV_D3D9, CONV_D3D10, CONV_D3D11, CONV_DEF, CONV_DEVICE, CONV_DEVICE_FUNC, CONV_EGL, CONV_ERROR, CONV_EVENT, From 9a98f94fa51d3ee3b2909cb7673de45ec596d5eb Mon Sep 17 00:00:00 2001 From: Pruthvi Madugundu Date: Thu, 28 Sep 2023 01:00:54 -0700 Subject: [PATCH 4/6] Updated changes for SWDEV-407984 --- aten/src/ATen/cuda/CUDABlas.cpp | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp index fc3fdd9675314..59139bbdcebb8 100644 --- a/aten/src/ATen/cuda/CUDABlas.cpp +++ b/aten/src/ATen/cuda/CUDABlas.cpp @@ -1452,12 +1452,24 @@ void getrsBatched>(CUDABLAS_GETRS_ARGTYPES(c10::complex void geqrfBatched(CUDABLAS_GEQRF_BATCHED_ARGTYPES(float)) { +#if ROCM_VERSION >= 50700 + if (batchsize == 0) { + *info = 0; + return ; + } +#endif TORCH_CUDABLAS_CHECK(cublasSgeqrfBatched( handle, m, n, A_array, lda, tau_array, info, batchsize)); } template <> void geqrfBatched(CUDABLAS_GEQRF_BATCHED_ARGTYPES(double)) { +#if ROCM_VERSION >= 50700 + if (batchsize == 0) { + *info = 0; + return ; + } +#endif TORCH_CUDABLAS_CHECK(cublasDgeqrfBatched( handle, m, n, A_array, lda, tau_array, info, batchsize)); } @@ -1465,6 +1477,12 @@ void 
geqrfBatched(CUDABLAS_GEQRF_BATCHED_ARGTYPES(double)) { template <> void geqrfBatched>( CUDABLAS_GEQRF_BATCHED_ARGTYPES(c10::complex)) { +#if ROCM_VERSION >= 50700 + if (batchsize == 0) { + *info = 0; + return ; + } +#endif TORCH_CUDABLAS_CHECK(cublasCgeqrfBatched( handle, m, @@ -1479,6 +1497,12 @@ void geqrfBatched>( template <> void geqrfBatched>( CUDABLAS_GEQRF_BATCHED_ARGTYPES(c10::complex)) { +#if ROCM_VERSION >= 50700 + if (batchsize == 0) { + *info = 0; + return ; + } +#endif TORCH_CUDABLAS_CHECK(cublasZgeqrfBatched( handle, m, From 21f1c0a75164f6be72f24427632a7f2b63e9ace3 Mon Sep 17 00:00:00 2001 From: Pruthvi Madugundu Date: Thu, 28 Sep 2023 01:01:57 -0700 Subject: [PATCH 5/6] Update a missing constant in hipify --- torch/utils/hipify/cuda_to_hip_mappings.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/utils/hipify/cuda_to_hip_mappings.py b/torch/utils/hipify/cuda_to_hip_mappings.py index 984f3148e9293..c09dd98dc5676 100644 --- a/torch/utils/hipify/cuda_to_hip_mappings.py +++ b/torch/utils/hipify/cuda_to_hip_mappings.py @@ -4,7 +4,7 @@ import subprocess from .constants import (API_BLAS, API_C10, API_CAFFE2, API_DRIVER, API_FFT, - API_PYTORCH, API_RAND, API_ROCTX, API_RTC, API_RUNTIME, + API_PYTORCH, API_PYT_EXT, API_RAND, API_ROCTX, API_RTC, API_RUNTIME, API_SPECIAL, API_ROCMSMI, CONV_CACHE, CONV_CONTEXT, CONV_D3D9, CONV_D3D10, CONV_D3D11, CONV_DEF, CONV_DEVICE, CONV_DEVICE_FUNC, CONV_EGL, CONV_ERROR, CONV_EVENT, From 3edf4622bb3c05bb9543d6c23ed6f4abdb90ba61 Mon Sep 17 00:00:00 2001 From: Pruthvi Madugundu Date: Thu, 28 Sep 2023 12:15:23 -0700 Subject: [PATCH 6/6] NIT related changes --- aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLibBlas.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLibBlas.cpp b/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLibBlas.cpp index 9f7c0c02663b9..f130c29db9677 100644 --- a/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLibBlas.cpp +++ b/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLibBlas.cpp @@ -134,7 +134,7 @@ static void apply_lu_solve_batched_cublas(const Tensor& LU, const Tensor& pivots const auto trans = to_cublas(transpose); auto pivots_data = pivots.data_ptr(); - auto batch_size = cuda_int_cast(batchCount(LU), "batch_size");; + auto batch_size = cuda_int_cast(batchCount(LU), "batch_size"); auto m = cuda_int_cast(LU.size(-2), "m"); auto nrhs = cuda_int_cast(B.size(-1), "nrhs"); auto lda = cuda_int_cast(std::max(1, m), "lda");
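
Note: the sketch below is illustrative only and is not part of the patch series. It shows one hedged way to exercise, from Python, the zero-batch geqrfBatched path that the SWDEV-407984 guard re-added in patch 4 protects (batchsize == 0 sets *info = 0 and returns early). It assumes a GPU build of PyTorch; whether the batched cuBLAS/hipBLAS driver (rather than cuSOLVER/hipSOLVER) is actually selected for QR depends on the build and on the linalg backend heuristics, so treat it as an illustration rather than a guaranteed reproducer.

    import torch

    # Hedged sketch: drive torch.linalg.qr with a batch dimension of size 0.
    # batchsize == 0 is exactly the case the SWDEV-407984 workaround handles;
    # backend dispatch may or may not reach geqrfBatched on a given build.
    def empty_batch_qr(device="cuda", dtype=torch.float32):
        a = torch.randn(0, 4, 4, device=device, dtype=dtype)  # empty batch
        q, r = torch.linalg.qr(a)                              # should not crash
        return q.shape, r.shape                                # both (0, 4, 4)

    if __name__ == "__main__":
        if torch.cuda.is_available():                          # also True on ROCm builds
            print(empty_batch_qr())
        else:
            print("No GPU available; nothing to exercise.")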