diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp
index 2538c03350a1..8e56d00b90dd 100644
--- a/aten/src/ATen/cuda/CUDABlas.cpp
+++ b/aten/src/ATen/cuda/CUDABlas.cpp
@@ -1457,12 +1457,24 @@ void getrsBatched<c10::complex<double>>(CUDABLAS_GETRS_ARGTYPES(c10::complex<double>)) {
 
 template <>
 void geqrfBatched<float>(HIPBLAS_GEQRF_BATCHED_ARGTYPES(float)) {
+#if ROCM_VERSION >= 50700 && ROCM_VERSION < 50800
+  if (batchsize == 0) {
+    *info = 0;
+    return ;
+  }
+#endif
   TORCH_HIPBLAS_CHECK(cublasSgeqrfBatched(
       handle, m, n, A_array, lda, tau_array, info, batchsize));
 }
 
 template <>
 void geqrfBatched<double>(HIPBLAS_GEQRF_BATCHED_ARGTYPES(double)) {
+#if ROCM_VERSION >= 50700 && ROCM_VERSION < 50800
+  if (batchsize == 0) {
+    *info = 0;
+    return ;
+  }
+#endif
   TORCH_HIPBLAS_CHECK(cublasDgeqrfBatched(
       handle, m, n, A_array, lda, tau_array, info, batchsize));
 }
@@ -1472,6 +1484,12 @@ void geqrfBatched<double>(HIPBLAS_GEQRF_BATCHED_ARGTYPES(double)) {
 template <>
 void geqrfBatched<c10::complex<float>>(
     HIPBLAS_GEQRF_BATCHED_ARGTYPES(c10::complex<float>)) {
+#if ROCM_VERSION >= 50700 && ROCM_VERSION < 50800
+  if (batchsize == 0) {
+    *info = 0;
+    return ;
+  }
+#endif
   TORCH_HIPBLAS_CHECK(cublasCgeqrfBatched(
       handle,
       m,
@@ -1486,6 +1504,12 @@ void geqrfBatched<c10::complex<float>>(
 template <>
 void geqrfBatched<c10::complex<double>>(
     HIPBLAS_GEQRF_BATCHED_ARGTYPES(c10::complex<double>)) {
+#if ROCM_VERSION >= 50700 && ROCM_VERSION < 50800
+  if (batchsize == 0) {
+    *info = 0;
+    return ;
+  }
+#endif
   TORCH_HIPBLAS_CHECK(cublasZgeqrfBatched(
       handle,
       m,