diff --git a/aten/src/ATen/cuda/CUDASparseDescriptors.cpp b/aten/src/ATen/cuda/CUDASparseDescriptors.cpp
index 84711be2ddf33..092314ac81f21 100644
--- a/aten/src/ATen/cuda/CUDASparseDescriptors.cpp
+++ b/aten/src/ATen/cuda/CUDASparseDescriptors.cpp
@@ -75,12 +75,7 @@ cusparseDnMatDescr_t createRawDnMatDescriptor(const Tensor& input, int64_t batch
   auto leading_dimension =
       is_row_major ? input_strides[ndim - 2] : input_strides[ndim - 1];
 
-#if !defined(USE_ROCM)
   auto order = is_row_major ? CUSPARSE_ORDER_ROW : CUSPARSE_ORDER_COL;
-#else
-  TORCH_INTERNAL_ASSERT(is_column_major, "Expected column major input.");
-  auto order = CUSPARSE_ORDER_COL;
-#endif
 
   auto batch_stride = ndim > 2 && batch_offset >= 0 ? input_strides[ndim - 3] : 0;
   // NOLINTNEXTLINE(*const-cast)
diff --git a/aten/src/ATen/native/sparse/cuda/SparseBlasImpl.cpp b/aten/src/ATen/native/sparse/cuda/SparseBlasImpl.cpp
index e43df8e048e81..6546707b4d320 100644
--- a/aten/src/ATen/native/sparse/cuda/SparseBlasImpl.cpp
+++ b/aten/src/ATen/native/sparse/cuda/SparseBlasImpl.cpp
@@ -615,7 +615,7 @@ void spmm(
   // CUDA < 11.0 doesn't support 64-bit indices and doesn't raise an error about this
   // silently returning incorrect results
-#if defined(USE_ROCM)
+#if defined(USE_ROCM) && (ROCM_VERSION < 60300)
   auto mat1_32 = at::native::_sparse_csr_tensor_unsafe(
       mat1.crow_indices().to(kInt),
       mat1.col_indices().to(kInt),
diff --git a/aten/src/ATen/native/sparse/cuda/SparseCUDABlas.cpp b/aten/src/ATen/native/sparse/cuda/SparseCUDABlas.cpp
index 133a73505dcf0..6b414e3f13635 100644
--- a/aten/src/ATen/native/sparse/cuda/SparseCUDABlas.cpp
+++ b/aten/src/ATen/native/sparse/cuda/SparseCUDABlas.cpp
@@ -63,6 +63,14 @@ const char* cusparseGetErrorString(cusparseStatus_t status) {
     case CUSPARSE_STATUS_ZERO_PIVOT:
       return "an entry of the matrix is either structural zero or numerical zero (singular block)";
 
+    #if defined(USE_ROCM)
+    case CUSPARSE_STATUS_NOT_SUPPORTED:
+      return "operation is not supported";
+
+    case CUSPARSE_STATUS_INSUFFICIENT_RESOURCES:
+      return "Resources are insufficient";
+    #endif // defined(USE_ROCM)
+
     default:
       return "unknown error";
   }
diff --git a/aten/src/ATen/native/sparse/cuda/SparseMatMul.cu b/aten/src/ATen/native/sparse/cuda/SparseMatMul.cu
index 23fcd26cf412f..89bb1da3f1218 100644
--- a/aten/src/ATen/native/sparse/cuda/SparseMatMul.cu
+++ b/aten/src/ATen/native/sparse/cuda/SparseMatMul.cu
@@ -40,12 +40,24 @@
 #include
 
-#if defined(__CUDACC__) && (CUSPARSE_VERSION >= 11000)
+#if defined(__CUDACC__) && ((CUSPARSE_VERSION >= 11000) || (defined(USE_ROCM) && ROCM_VERSION >= 60300))
 #define IS_CUSPARSE11_AVAILABLE() 1
 #else
 #define IS_CUSPARSE11_AVAILABLE() 0
 #endif
 
+#if defined(USE_ROCM) && (ROCM_VERSION >= 70000)
+#define HIPSPARSE_FP16_SUPPORT 1
+#else
+#define HIPSPARSE_FP16_SUPPORT 0
+#endif
+
+#if defined(USE_ROCM) && (ROCM_VERSION >= 70100)
+#define HIPSPARSE_FP16_BF16_SUPPORT 1
+#else
+#define HIPSPARSE_FP16_BF16_SUPPORT 0
+#endif
+
 #if IS_CUSPARSE11_AVAILABLE()
 #include
 #endif
@@ -207,13 +219,24 @@ struct CusparseMatrixMultiplyOp {
   CusparseMatrixMultiplyOp() {
     static_assert(
-        std::is_same_v<c10::Half, scalar_t> ||
-        std::is_same_v<c10::BFloat16, scalar_t> ||
+      #if !defined(USE_ROCM) || HIPSPARSE_FP16_SUPPORT
+        std::is_same_v<c10::Half, scalar_t> ||
+      #endif
+      #if !defined(USE_ROCM) || HIPSPARSE_FP16_BF16_SUPPORT
+        std::is_same_v<c10::BFloat16, scalar_t> ||
+      #endif
         std::is_same_v<float, scalar_t> ||
         std::is_same_v<double, scalar_t> ||
         std::is_same_v<c10::complex<float>, scalar_t> ||
         std::is_same_v<c10::complex<double>, scalar_t>,
-        "cusparseSpGEMM only supports data type of half, bfloat16, float, double and complex float, double.");
+        "cusparseSpGEMM only supports data type of "
+      #if !defined(USE_ROCM) || HIPSPARSE_FP16_SUPPORT
+        "half, "
+      #endif
+      #if !defined(USE_ROCM) || HIPSPARSE_FP16_BF16_SUPPORT
+        "bfloat16, "
+      #endif
+        "float, double and complex float, double.");
 
     // SpGEMM Computation
     TORCH_CUDASPARSE_CHECK(cusparseSpGEMM_createDescr(&spgemmDesc));
   }
@@ -268,11 +291,13 @@ struct CusparseMatrixMultiplyOp {
 
     // If a specific GPU model does not provide native support for a given data type,
     // the routine returns CUSPARSE_STATUS_ARCH_MISMATCH error
+    #if !defined(USE_ROCM)
     cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
     TORCH_CHECK(prop->major >= 5 && !((10*prop->major + prop->minor) < 53 && computeType == CUDA_R_16F),
         "sparse_mm: CUDA Float16 requires compute capability >= 53 (current: ", prop->major, prop->minor, ")");
     TORCH_CHECK(!(prop->major < 8 && computeType == CUDA_R_16BF),
         "sparse_mm: CUDA BFloat16 requires compute capability >= 80 (current: ", prop->major, prop->minor, ")");
+    #endif
 
     // ask bufferSize1 bytes for external memory
     TORCH_CUDASPARSE_CHECK(cusparseSpGEMM_workEstimation(
@@ -811,10 +836,20 @@ Tensor sparse_sparse_matmul_cuda(const Tensor& mat1_, const Tensor& mat2_) {
   output.sparse_resize_and_clear_({mat1_.size(0), mat2_.size(1)}, mat1_.sparse_dim(), 0);
 
 #if IS_CUSPARSE11_AVAILABLE()
-  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, mat1_.scalar_type(), "sparse_matmul", [&] {
-    sparse_sparse_matmul_cuda_kernel<scalar_t>(output, mat1_.coalesce(), mat2_.coalesce());
-  });
-#else
+  #if !defined(USE_ROCM) || HIPSPARSE_FP16_BF16_SUPPORT
+    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, mat1_.scalar_type(), "sparse_matmul", [&] {
+      sparse_sparse_matmul_cuda_kernel<scalar_t>(output, mat1_.coalesce(), mat2_.coalesce());
+    });
+  #elif HIPSPARSE_FP16_SUPPORT
+    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kHalf, mat1_.scalar_type(), "sparse_matmul", [&] {
+      sparse_sparse_matmul_cuda_kernel<scalar_t>(output, mat1_.coalesce(), mat2_.coalesce());
+    });
+  #else
+    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(mat1_.scalar_type(), "sparse_matmul", [&] {
+      sparse_sparse_matmul_cuda_kernel<scalar_t>(output, mat1_.coalesce(), mat2_.coalesce());
+    });
+  #endif
+#else // not IS_CUSPARSE11_AVAILABLE()
   AT_DISPATCH_FLOATING_TYPES(mat1_.scalar_type(), "sparse_matmul", [&] {
     sparse_sparse_matmul_cuda_kernel<scalar_t>(output, mat1_.coalesce(), mat2_.coalesce());
   });
diff --git a/test/test_sparse.py b/test/test_sparse.py
index 64d7ad9b1c2a7..713ec1bbd2a93 100644
--- a/test/test_sparse.py
+++ b/test/test_sparse.py
@@ -8,7 +8,7 @@
 import random
 import unittest
 from torch.testing import make_tensor
-from torch.testing._internal.common_utils import TestCase, run_tests, skipIfRocm, do_test_dtypes, \
+from torch.testing._internal.common_utils import TestCase, run_tests, do_test_dtypes, \
     load_tests, TEST_NUMPY, TEST_SCIPY, IS_WINDOWS, gradcheck, coalescedonoff, \
     DeterministicGuard, first_sample, TEST_WITH_CROSSREF, TEST_WITH_ROCM, skipIfTorchDynamo, \
     parametrize, subtest, is_coalesced_indices, suppress_warnings, instantiate_parametrized_tests, \
@@ -68,6 +68,12 @@ def _op_supports_any_sparse(op):
 ) or (not IS_WINDOWS and not TEST_WITH_ROCM)
 
 HIPSPARSE_SPMM_COMPLEX128_SUPPORTED = torch.version.hip and version.parse(torch.version.hip.split("-")[0]) >= version.parse("6.0")
+HIPSPARSE_FP16_SUPPORTED = torch.version.hip and version.parse(torch.version.hip.split("-")[0]) >= version.parse("7.0")
+HIPSPARSE_BF16_SUPPORTED = torch.version.hip and version.parse(torch.version.hip.split("-")[0]) >= version.parse("7.1")
+
+SPARSE_COMPLEX128_SUPPORTED = CUSPARSE_SPMM_COMPLEX128_SUPPORTED or HIPSPARSE_SPMM_COMPLEX128_SUPPORTED
+SPARSE_FLOAT16_SUPPORTED = (SM53OrLater and torch.version.cuda) or (HIPSPARSE_FP16_SUPPORTED)
+SPARSE_BFLOAT16_SUPPORTED = (SM80OrLater and torch.version.cuda) or (HIPSPARSE_BF16_SUPPORTED)
 
 def all_sparse_layouts(test_name='layout', include_strided=False):
     return parametrize(test_name, [
@@ -3608,13 +3614,12 @@ def test_log_softmax_zero_nnz(self, device, dtype):
         self._check_zero_nnz_softmax_op(torch.sparse.log_softmax, 10, device, dtype)
 
     # TODO: Check after why ROCm's cusparseXcsrgemm2Nnz function doesn't return the same nnz value as CUDA
-    @skipIfRocm
     @coalescedonoff
     @dtypes(*floating_and_complex_types())
-    @dtypesIfCUDA(*floating_types_and(*[torch.half] if SM53OrLater else [],
-                                      *[torch.bfloat16] if SM80OrLater else [],
+    @dtypesIfCUDA(*floating_types_and(*[torch.half] if SPARSE_FLOAT16_SUPPORTED else [],
+                                      *[torch.bfloat16] if SPARSE_BFLOAT16_SUPPORTED else [],
                                       torch.complex64,
-                                      *[torch.complex128] if CUSPARSE_SPMM_COMPLEX128_SUPPORTED else []))
+                                      *[torch.complex128] if SPARSE_COMPLEX128_SUPPORTED else []))
     @unittest.skipIf(TEST_WITH_CROSSREF, "not working with fake tensor")
     @precisionOverride({torch.bfloat16: 1e-2, torch.float16: 1e-2, torch.complex64: 1e-2, torch.float32: 1e-2})
     def test_sparse_matmul(self, device, dtype, coalesced):
diff --git a/test/test_sparse_csr.py b/test/test_sparse_csr.py
index 4ef62c9184d06..a47e05278f6fb 100644
--- a/test/test_sparse_csr.py
+++ b/test/test_sparse_csr.py
@@ -12,8 +12,8 @@
 from torch.testing._internal.common_cuda import SM53OrLater, SM80OrLater, TEST_CUSPARSE_GENERIC
 from torch.testing._internal.common_utils import \
     (TEST_WITH_TORCHINDUCTOR, TEST_WITH_ROCM, TEST_CUDA_CUDSS, TEST_SCIPY, TEST_NUMPY, TEST_MKL, IS_WINDOWS, TestCase,
-     run_tests, load_tests, coalescedonoff, parametrize, subtest, skipIfTorchDynamo, skipIfRocm, IS_FBCODE, IS_REMOTE_GPU,
-     suppress_warnings)
+     run_tests, load_tests, coalescedonoff, parametrize, subtest, skipIfTorchDynamo, skipIfRocm,
+     skipIfRocmVersionLessThan, IS_FBCODE, IS_REMOTE_GPU, suppress_warnings)
 from torch.testing._internal.common_device_type import \
     (ops, instantiate_device_type_tests, dtypes, OpDTypes, dtypesIfCUDA, onlyCPU, onlyCUDA, skipCUDAIfNoSparseGeneric,
      precisionOverride, skipMeta, skipCUDAIf, skipCPUIfNoMklSparse, skipCUDAIfRocmVersionLessThan,
@@ -26,7 +26,8 @@
     all_types_and_complex, floating_and_complex_types_and)
 from torch.testing._internal.opinfo.definitions.linalg import sample_inputs_linalg_solve
 from torch.testing._internal.opinfo.definitions.sparse import validate_sample_input_sparse
-from test_sparse import CUSPARSE_SPMM_COMPLEX128_SUPPORTED, HIPSPARSE_SPMM_COMPLEX128_SUPPORTED
+from test_sparse import HIPSPARSE_BF16_SUPPORTED, HIPSPARSE_FP16_SUPPORTED, \
+    SPARSE_FLOAT16_SUPPORTED, SPARSE_BFLOAT16_SUPPORTED, SPARSE_COMPLEX128_SUPPORTED
 import operator
 
 if TEST_SCIPY:
@@ -1545,9 +1546,10 @@ def run_test(c, a, a_batched, b, op_b=False, op_out=False, *, dtype=None, device
             run_test(c, a, a_batched, b, op_b, op_out, dtype=dtype, device=device)
 
     @onlyCUDA
-    @unittest.skipIf(TEST_WITH_ROCM, "Only CUDA 11+ is supported")
+    @skipIfRocmVersionLessThan((6, 3))
     @skipCUDAIfNoSparseGeneric
-    @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
+    @dtypes(*floating_and_complex_types_and(*[torch.half] if HIPSPARSE_FP16_SUPPORTED else [],
+                                            *[torch.bfloat16] if HIPSPARSE_BF16_SUPPORTED else []))
     def test_bmm(self, device, dtype):
        def run_test(a, a_batched, b, op_b=False, op_out=False, *, dtype=None, device=None):
            b = b.mH if (op_b and a.shape == b.shape) else b
@@ -1834,7 +1836,7 @@ def run_test(a, b, upper, transpose, unitriangular, op_out):
                     run_test(a, b, upper, unitriangular, transpose, op_out)
 
     @skipCPUIfNoMklSparse
-    @unittest.skipIf(TEST_WITH_ROCM, "Only CUDA 11+ is supported")
+    @skipIfRocmVersionLessThan((6, 3))
     @dtypes(torch.double)
     def test_mm(self, device, dtype):
         def test_shape(di, dj, dk, nnz0=None, nnz1=None):
@@ -1954,8 +1956,8 @@ def test_shape(d1, d2, d3, nnz, transposed, index_dtype):
 
     @dtypes(*floating_and_complex_types())
     @dtypesIfCUDA(*floating_and_complex_types_and(
-        *[torch.half] if SM53OrLater and TEST_CUSPARSE_GENERIC else [],
-        *[torch.bfloat16] if SM80OrLater and TEST_CUSPARSE_GENERIC else []))
+        *[torch.half] if SPARSE_FLOAT16_SUPPORTED else [],
+        *[torch.bfloat16] if SPARSE_BFLOAT16_SUPPORTED else []))
     @precisionOverride({torch.bfloat16: 3.5e-2, torch.float16: 1e-2})
     def test_sparse_addmm(self, device, dtype):
         def test_shape(m, n, p, nnz, broadcast, index_dtype, alpha_beta=None):
@@ -1984,18 +1986,15 @@ def test_shape(m, n, p, nnz, broadcast, index_dtype, alpha_beta=None):
             test_shape(7, 8, 9, 20, True, index_dtype, (1, 1))
 
     @skipCPUIfNoMklSparse
+    @skipIfRocmVersionLessThan((6, 3))
     @dtypes(*floating_and_complex_types())
     @precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 0.6,
                         torch.half: 1e-1, torch.cfloat: 1e-4, torch.cdouble: 1e-8})
     @dtypesIfCUDA(*floating_types_and(torch.complex64,
-                                      *[torch.bfloat16] if SM80OrLater else [],
-                                      *[torch.half] if SM53OrLater else [],
-                                      *[torch.complex128] if CUSPARSE_SPMM_COMPLEX128_SUPPORTED else []))
+                                      *[torch.bfloat16] if SPARSE_BFLOAT16_SUPPORTED else [],
+                                      *[torch.half] if SPARSE_FLOAT16_SUPPORTED else [],
+                                      *[torch.complex128] if SPARSE_COMPLEX128_SUPPORTED else []))
     @sparse_compressed_nonblock_layouts()
-    @skipCUDAIf(
-        not _check_cusparse_spgemm_available(),
-        "cuSparse Generic API SpGEMM is not available"
-    )
     def test_addmm_all_sparse_csr(self, device, dtype, layout):
         M = torch.randn(10, 25, device=device).to(dtype)
         m1 = torch.randn(10, 50, device=device).to(dtype)
@@ -2066,16 +2065,12 @@ def maybe_transpose(cond, m):
     @skipCPUIfNoMklSparse
     @dtypes(*floating_and_complex_types())
     @dtypesIfCUDA(*floating_types_and(torch.complex64,
-                                      *[torch.bfloat16] if SM80OrLater else [],
-                                      *[torch.half] if SM53OrLater else [],
-                                      *[torch.complex128]
-                                      if CUSPARSE_SPMM_COMPLEX128_SUPPORTED or HIPSPARSE_SPMM_COMPLEX128_SUPPORTED
-                                      else []))
+                                      *[torch.bfloat16] if SPARSE_BFLOAT16_SUPPORTED else [],
+                                      *[torch.half] if SPARSE_FLOAT16_SUPPORTED else [],
+                                      *[torch.complex128] if SPARSE_COMPLEX128_SUPPORTED else []))
     @precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 0.6,
                         torch.half: 1e-1, torch.cfloat: 1e-4, torch.cdouble: 1e-8})
     def test_addmm_sizes_all_sparse_csr(self, device, dtype, m, n, k):
-        if (TEST_WITH_ROCM and k != 0 and n != 0 and m != 0):
-            self.skipTest("Skipped on ROCm")
         M = torch.randn(n, m, device=device).to(dtype)
         m1 = torch.randn(n, k, device=device).to(dtype)
         m2 = torch.randn(k, m, device=device).to(dtype)
diff --git a/torch/utils/hipify/cuda_to_hip_mappings.py b/torch/utils/hipify/cuda_to_hip_mappings.py
index bc871451c5556..164bda3a1442a 100644
--- a/torch/utils/hipify/cuda_to_hip_mappings.py
+++ b/torch/utils/hipify/cuda_to_hip_mappings.py
@@ -8175,11 +8175,15 @@
     ("cusparseSpGEMMDescr_t", ("hipsparseSpGEMMDescr_t", CONV_TYPE, API_SPECIAL)),
     ("CUSPARSE_INDEX_32I", ("HIPSPARSE_INDEX_32I", CONV_NUMERIC_LITERAL, API_SPECIAL)),
     ("CUSPARSE_INDEX_64I", ("HIPSPARSE_INDEX_64I", CONV_NUMERIC_LITERAL, API_SPECIAL)),
-    ("CUSPARSE_ORDER_COL", ("HIPSPARSE_ORDER_COLUMN", CONV_NUMERIC_LITERAL, API_SPECIAL)),
+    ("CUSPARSE_ORDER_COL", ("HIPSPARSE_ORDER_COL", CONV_NUMERIC_LITERAL, API_SPECIAL)),
+    ("CUSPARSE_ORDER_ROW", ("HIPSPARSE_ORDER_ROW", CONV_NUMERIC_LITERAL, API_SPECIAL)),
     ("CUSPARSE_MV_ALG_DEFAULT", ("HIPSPARSE_MV_ALG_DEFAULT", CONV_NUMERIC_LITERAL, API_SPECIAL)),
     ("CUSPARSE_MM_ALG_DEFAULT", ("HIPSPARSE_MM_ALG_DEFAULT", CONV_NUMERIC_LITERAL, API_SPECIAL)),
     ("CUSPARSE_SPMM_COO_ALG1", ("HIPSPARSE_SPMM_COO_ALG1", CONV_NUMERIC_LITERAL, API_SPECIAL)),
     ("CUSPARSE_SPMM_COO_ALG2", ("HIPSPARSE_SPMM_COO_ALG2", CONV_NUMERIC_LITERAL, API_SPECIAL)),
+    ("CUSPARSE_SPMM_CSR_ALG1", ("HIPSPARSE_SPMM_CSR_ALG1", CONV_NUMERIC_LITERAL, API_SPECIAL)),
+    ("CUSPARSE_SPMM_CSR_ALG2", ("HIPSPARSE_SPMM_CSR_ALG2", CONV_NUMERIC_LITERAL, API_SPECIAL)),
+    ("CUSPARSE_SPMM_CSR_ALG3", ("HIPSPARSE_SPMM_CSR_ALG3", CONV_NUMERIC_LITERAL, API_SPECIAL)),
     ("CUSPARSE_COOMV_ALG", ("HIPSPARSE_COOMV_ALG", CONV_NUMERIC_LITERAL, API_SPECIAL)),
     ("CUSPARSE_SPMM_CSR_ALG1", ("HIPSPARSE_CSRMM_ALG1", CONV_NUMERIC_LITERAL, API_SPECIAL)),
     ("CUSPARSE_SPGEMM_DEFAULT", ("HIPSPARSE_SPGEMM_DEFAULT", CONV_NUMERIC_LITERAL, API_SPECIAL)),
@@ -8228,6 +8232,14 @@
         "CUSPARSE_STATUS_ZERO_PIVOT",
         ("HIPSPARSE_STATUS_ZERO_PIVOT", CONV_NUMERIC_LITERAL, API_SPECIAL),
     ),
+    (
+        "CUSPARSE_STATUS_NOT_SUPPORTED",
+        ("HIPSPARSE_STATUS_NOT_SUPPORTED", CONV_NUMERIC_LITERAL, API_SPECIAL),
+    ),
+    (
+        "CUSPARSE_STATUS_INSUFFICIENT_RESOURCES",
+        ("HIPSPARSE_STATUS_INSUFFICIENT_RESOURCES", CONV_NUMERIC_LITERAL, API_SPECIAL),
+    ),
     (
         "CUSPARSE_OPERATION_TRANSPOSE",
         ("HIPSPARSE_OPERATION_TRANSPOSE", CONV_NUMERIC_LITERAL, API_SPECIAL),