aten/src/ATen/cuda/CUDASparseDescriptors.cpp (5 changes: 0 additions & 5 deletions)
@@ -75,12 +75,7 @@ cusparseDnMatDescr_t createRawDnMatDescriptor(const Tensor& input, int64_t batch
   auto leading_dimension =
       is_row_major ? input_strides[ndim - 2] : input_strides[ndim - 1];
 
-#if !defined(USE_ROCM)
   auto order = is_row_major ? CUSPARSE_ORDER_ROW : CUSPARSE_ORDER_COL;
-#else
-  TORCH_INTERNAL_ASSERT(is_column_major, "Expected column major input.");
-  auto order = CUSPARSE_ORDER_COL;
-#endif
 
   auto batch_stride = ndim > 2 && batch_offset >= 0 ? input_strides[ndim - 3] : 0;
   // NOLINTNEXTLINE(*const-cast)

aten/src/ATen/native/sparse/cuda/SparseBlasImpl.cpp (2 changes: 1 addition & 1 deletion)
@@ -615,7 +615,7 @@ void spmm(
 
   // CUDA < 11.0 doesn't support 64-bit indices and doesn't raise an error about this
   // silently returning incorrect results
-#if defined(USE_ROCM)
+#if defined(USE_ROCM) && (ROCM_VERSION < 60300)
   auto mat1_32 = at::native::_sparse_csr_tensor_unsafe(
       mat1.crow_indices().to(kInt),
       mat1.col_indices().to(kInt),
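
Note on the version gate: in PyTorch's ROCm build, ROCM_VERSION is an integer encoded as major * 10000 + minor * 100 + patch, so `ROCM_VERSION < 60300` reads as "older than ROCm 6.3.0"; only those older versions keep the 32-bit index fallback. A minimal sketch of the encoding, with a helper name of our own (not a PyTorch API):

```python
# Sketch of the ROCM_VERSION integer encoding assumed by the gate above.
# rocm_version_int is a hypothetical helper, not part of PyTorch.
def rocm_version_int(version_str: str) -> int:
    major, minor, patch = (int(part) for part in version_str.split("."))
    return major * 10000 + minor * 100 + patch

assert rocm_version_int("6.3.0") == 60300  # threshold used in SparseBlasImpl.cpp
assert rocm_version_int("7.1.0") == 70100  # threshold used for bf16 support below
```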

aten/src/ATen/native/sparse/cuda/SparseCUDABlas.cpp (8 changes: 8 additions & 0 deletions)
@@ -63,6 +63,14 @@ const char* cusparseGetErrorString(cusparseStatus_t status) {
     case CUSPARSE_STATUS_ZERO_PIVOT:
       return "an entry of the matrix is either structural zero or numerical zero (singular block)";
 
+#if defined(USE_ROCM)
+    case CUSPARSE_STATUS_NOT_SUPPORTED:
+      return "operation is not supported";
+
+    case CUSPARSE_STATUS_INSUFFICIENT_RESOURCES:
+      return "Resources are insufficient";
+#endif // defined(USE_ROCM)
+
     default:
       return "unknown error";
   }

aten/src/ATen/native/sparse/cuda/SparseMatMul.cu (51 changes: 43 additions & 8 deletions)
@@ -40,12 +40,24 @@
 #include <thrust/iterator/discard_iterator.h>
 
 
-#if defined(__CUDACC__) && (CUSPARSE_VERSION >= 11000)
+#if defined(__CUDACC__) && ((CUSPARSE_VERSION >= 11000) || (defined(USE_ROCM) && ROCM_VERSION >= 60300))
 #define IS_CUSPARSE11_AVAILABLE() 1
 #else
 #define IS_CUSPARSE11_AVAILABLE() 0
 #endif
 
+#if defined(USE_ROCM) && (ROCM_VERSION >= 70000)
+#define HIPSPARSE_FP16_SUPPORT 1
+#else
+#define HIPSPARSE_FP16_SUPPORT 0
+#endif
+
+#if defined(USE_ROCM) && (ROCM_VERSION >= 70100)
+#define HIPSPARSE_FP16_BF16_SUPPORT 1
+#else
+#define HIPSPARSE_FP16_BF16_SUPPORT 0
+#endif
+
 #if IS_CUSPARSE11_AVAILABLE()
 #include <library_types.h>
 #endif
@@ -207,13 +219,24 @@ struct CusparseMatrixMultiplyOp {
 
   CusparseMatrixMultiplyOp() {
     static_assert(
-        std::is_same_v<c10::Half, scalar_t> ||
-        std::is_same_v<c10::BFloat16, scalar_t> ||
+#if !defined(USE_ROCM) || HIPSPARSE_FP16_SUPPORT
+        std::is_same_v<c10::Half, scalar_t> ||
+#endif
+#if !defined(USE_ROCM) || HIPSPARSE_FP16_BF16_SUPPORT
+        std::is_same_v<c10::BFloat16, scalar_t> ||
+#endif
         std::is_same_v<float, scalar_t> ||
         std::is_same_v<double, scalar_t> ||
         std::is_same_v<c10::complex<float>, scalar_t> ||
        std::is_same_v<c10::complex<double>, scalar_t>,
-        "cusparseSpGEMM only supports data type of half, bfloat16, float, double and complex float, double.");
+        "cusparseSpGEMM only supports data type of "
+#if !defined(USE_ROCM) || HIPSPARSE_FP16_SUPPORT
+        "half, "
+#endif
+#if !defined(USE_ROCM) || HIPSPARSE_FP16_BF16_SUPPORT
+        "bfloat16, "
+#endif
+        "float, double and complex float, double.");
     // SpGEMM Computation
     TORCH_CUDASPARSE_CHECK(cusparseSpGEMM_createDescr(&spgemmDesc));
   }
@@ -268,11 +291,13 @@ struct CusparseMatrixMultiplyOp {
 
   // If a specific GPU model does not provide native support for a given data type,
   // the routine returns CUSPARSE_STATUS_ARCH_MISMATCH error
+#if !defined(USE_ROCM)
   cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
   TORCH_CHECK(prop->major >= 5 && !((10*prop->major + prop->minor) < 53 && computeType == CUDA_R_16F),
     "sparse_mm: CUDA Float16 requires compute capability >= 53 (current: ", prop->major, prop->minor, ")");
   TORCH_CHECK(!(prop->major < 8 && computeType == CUDA_R_16BF),
     "sparse_mm: CUDA BFloat16 requires compute capability >= 80 (current: ", prop->major, prop->minor, ")");
+#endif
 
   // ask bufferSize1 bytes for external memory
   TORCH_CUDASPARSE_CHECK(cusparseSpGEMM_workEstimation(
@@ -811,10 +836,20 @@ Tensor sparse_sparse_matmul_cuda(const Tensor& mat1_, const Tensor& mat2_) {
   output.sparse_resize_and_clear_({mat1_.size(0), mat2_.size(1)}, mat1_.sparse_dim(), 0);
 
 #if IS_CUSPARSE11_AVAILABLE()
-  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, mat1_.scalar_type(), "sparse_matmul", [&] {
-    sparse_sparse_matmul_cuda_kernel<scalar_t>(output, mat1_.coalesce(), mat2_.coalesce());
-  });
-#else
+#if !defined(USE_ROCM) || HIPSPARSE_FP16_BF16_SUPPORT
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, mat1_.scalar_type(), "sparse_matmul", [&] {
+    sparse_sparse_matmul_cuda_kernel<scalar_t>(output, mat1_.coalesce(), mat2_.coalesce());
+  });
+#elif HIPSPARSE_FP16_SUPPORT
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kHalf, mat1_.scalar_type(), "sparse_matmul", [&] {
+    sparse_sparse_matmul_cuda_kernel<scalar_t>(output, mat1_.coalesce(), mat2_.coalesce());
+  });
+#else
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(mat1_.scalar_type(), "sparse_matmul", [&] {
+    sparse_sparse_matmul_cuda_kernel<scalar_t>(output, mat1_.coalesce(), mat2_.coalesce());
+  });
+#endif
+#else // not IS_CUSPARSE11_AVAILABLE()
   AT_DISPATCH_FLOATING_TYPES(mat1_.scalar_type(), "sparse_matmul", [&] {
     sparse_sparse_matmul_cuda_kernel<scalar_t>(output, mat1_.coalesce(), mat2_.coalesce());
   });
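
Net effect of the gates above: with IS_CUSPARSE11_AVAILABLE now true on ROCm >= 6.3, SpGEMM dispatch covers float, double, and complex there, adds half at ROCm >= 7.0, and half plus bfloat16 at ROCm >= 7.1, while CUDA keeps its SM 5.3/8.0 compute-capability checks. A quick smoke test of the newly enabled fp16 path, assuming a ROCm >= 7.0 build and a supported GPU:

```python
import torch

# Sparse-sparse matmul in float16; on a ROCm >= 7.0 build this goes
# through the HIPSPARSE_FP16_SUPPORT dispatch branch added above.
a = torch.eye(4, device="cuda", dtype=torch.float16).to_sparse()
b = torch.eye(4, device="cuda", dtype=torch.float16).to_sparse()
c = torch.sparse.mm(a, b)  # returns a sparse COO tensor
print(c.to_dense())
```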

test/test_sparse.py (15 changes: 10 additions & 5 deletions)
@@ -8,7 +8,7 @@
 import random
 import unittest
 from torch.testing import make_tensor
-from torch.testing._internal.common_utils import TestCase, run_tests, skipIfRocm, do_test_dtypes, \
+from torch.testing._internal.common_utils import TestCase, run_tests, do_test_dtypes, \
     load_tests, TEST_NUMPY, TEST_SCIPY, IS_WINDOWS, gradcheck, coalescedonoff, \
     DeterministicGuard, first_sample, TEST_WITH_CROSSREF, TEST_WITH_ROCM, skipIfTorchDynamo, \
     parametrize, subtest, is_coalesced_indices, suppress_warnings, instantiate_parametrized_tests, \
@@ -68,6 +68,12 @@ def _op_supports_any_sparse(op):
 ) or (not IS_WINDOWS and not TEST_WITH_ROCM)
 
 HIPSPARSE_SPMM_COMPLEX128_SUPPORTED = torch.version.hip and version.parse(torch.version.hip.split("-")[0]) >= version.parse("6.0")
+HIPSPARSE_FP16_SUPPORTED = torch.version.hip and version.parse(torch.version.hip.split("-")[0]) >= version.parse("7.0")
+HIPSPARSE_BF16_SUPPORTED = torch.version.hip and version.parse(torch.version.hip.split("-")[0]) >= version.parse("7.1")
+
+SPARSE_COMPLEX128_SUPPORTED = CUSPARSE_SPMM_COMPLEX128_SUPPORTED or HIPSPARSE_SPMM_COMPLEX128_SUPPORTED
+SPARSE_FLOAT16_SUPPORTED = (SM53OrLater and torch.version.cuda) or (HIPSPARSE_FP16_SUPPORTED)
+SPARSE_BFLOAT16_SUPPORTED = (SM80OrLater and torch.version.cuda) or (HIPSPARSE_BF16_SUPPORTED)
 
 def all_sparse_layouts(test_name='layout', include_strided=False):
     return parametrize(test_name, [

Review comment from a collaborator on the HIPSPARSE_BF16_SUPPORTED line: "this is not yet tested, but hipsparse will support this in 7.1"
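
For context on these flags: torch.version.hip carries a build suffix after the release triple, which is why the guards call split("-")[0] before parsing. A standalone sketch of the same check (the version string here is made up for illustration):

```python
from packaging import version

# torch.version.hip is None on CUDA builds; on ROCm it looks roughly like
# "7.0.12345-abcdef", so we keep only the release part before comparing.
hip_version = "7.0.12345-abcdef"  # stand-in for torch.version.hip
release = version.parse(hip_version.split("-")[0])
print(release >= version.parse("7.0"))  # True:  fp16 sparse matmul expected
print(release >= version.parse("7.1"))  # False: bf16 still gated off
```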
@@ -3608,13 +3614,12 @@ def test_log_softmax_zero_nnz(self, device, dtype):
         self._check_zero_nnz_softmax_op(torch.sparse.log_softmax, 10, device, dtype)
 
     # TODO: Check after why ROCm's cusparseXcsrgemm2Nnz function doesn't return the same nnz value as CUDA
-    @skipIfRocm
     @coalescedonoff
     @dtypes(*floating_and_complex_types())
-    @dtypesIfCUDA(*floating_types_and(*[torch.half] if SM53OrLater else [],
-                                      *[torch.bfloat16] if SM80OrLater else [],
+    @dtypesIfCUDA(*floating_types_and(*[torch.half] if SPARSE_FLOAT16_SUPPORTED else [],
+                                      *[torch.bfloat16] if SPARSE_BFLOAT16_SUPPORTED else [],
                                       torch.complex64,
-                                      *[torch.complex128] if CUSPARSE_SPMM_COMPLEX128_SUPPORTED else []))
+                                      *[torch.complex128] if SPARSE_COMPLEX128_SUPPORTED else []))
     @unittest.skipIf(TEST_WITH_CROSSREF, "not working with fake tensor")
     @precisionOverride({torch.bfloat16: 1e-2, torch.float16: 1e-2, torch.complex64: 1e-2, torch.float32: 1e-2})
     def test_sparse_matmul(self, device, dtype, coalesced):

test/test_sparse_csr.py (39 changes: 17 additions & 22 deletions)
@@ -12,8 +12,8 @@
 from torch.testing._internal.common_cuda import SM53OrLater, SM80OrLater, TEST_CUSPARSE_GENERIC
 from torch.testing._internal.common_utils import \
     (TEST_WITH_TORCHINDUCTOR, TEST_WITH_ROCM, TEST_CUDA_CUDSS, TEST_SCIPY, TEST_NUMPY, TEST_MKL, IS_WINDOWS, TestCase,
-     run_tests, load_tests, coalescedonoff, parametrize, subtest, skipIfTorchDynamo, skipIfRocm, IS_FBCODE, IS_REMOTE_GPU,
-     suppress_warnings)
+     run_tests, load_tests, coalescedonoff, parametrize, subtest, skipIfTorchDynamo, skipIfRocm,
+     skipIfRocmVersionLessThan, IS_FBCODE, IS_REMOTE_GPU, suppress_warnings)
 from torch.testing._internal.common_device_type import \
     (ops, instantiate_device_type_tests, dtypes, OpDTypes, dtypesIfCUDA, onlyCPU, onlyCUDA, skipCUDAIfNoSparseGeneric,
      precisionOverride, skipMeta, skipCUDAIf, skipCPUIfNoMklSparse, skipCUDAIfRocmVersionLessThan,
@@ -26,7 +26,8 @@
     all_types_and_complex, floating_and_complex_types_and)
 from torch.testing._internal.opinfo.definitions.linalg import sample_inputs_linalg_solve
 from torch.testing._internal.opinfo.definitions.sparse import validate_sample_input_sparse
-from test_sparse import CUSPARSE_SPMM_COMPLEX128_SUPPORTED, HIPSPARSE_SPMM_COMPLEX128_SUPPORTED
+from test_sparse import HIPSPARSE_BF16_SUPPORTED, HIPSPARSE_FP16_SUPPORTED, \
+    SPARSE_FLOAT16_SUPPORTED, SPARSE_BFLOAT16_SUPPORTED, SPARSE_COMPLEX128_SUPPORTED
 import operator
 
 if TEST_SCIPY:
@@ -1545,9 +1546,10 @@ def run_test(c, a, a_batched, b, op_b=False, op_out=False, *, dtype=None, device=None):
             run_test(c, a, a_batched, b, op_b, op_out, dtype=dtype, device=device)
 
     @onlyCUDA
-    @unittest.skipIf(TEST_WITH_ROCM, "Only CUDA 11+ is supported")
+    @skipIfRocmVersionLessThan((6, 3))
     @skipCUDAIfNoSparseGeneric
-    @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
+    @dtypes(*floating_and_complex_types_and(*[torch.half] if HIPSPARSE_FP16_SUPPORTED else [],
+                                            *[torch.bfloat16] if HIPSPARSE_BF16_SUPPORTED else []))
     def test_bmm(self, device, dtype):
         def run_test(a, a_batched, b, op_b=False, op_out=False, *, dtype=None, device=None):
             b = b.mH if (op_b and a.shape == b.shape) else b
@@ -1834,7 +1836,7 @@ def run_test(a, b, upper, transpose, unitriangular, op_out):
             run_test(a, b, upper, unitriangular, transpose, op_out)
 
     @skipCPUIfNoMklSparse
-    @unittest.skipIf(TEST_WITH_ROCM, "Only CUDA 11+ is supported")
+    @skipIfRocmVersionLessThan((6, 3))
     @dtypes(torch.double)
     def test_mm(self, device, dtype):
         def test_shape(di, dj, dk, nnz0=None, nnz1=None):
@@ -1954,8 +1956,8 @@ def test_shape(d1, d2, d3, nnz, transposed, index_dtype):
 
     @dtypes(*floating_and_complex_types())
     @dtypesIfCUDA(*floating_and_complex_types_and(
-        *[torch.half] if SM53OrLater and TEST_CUSPARSE_GENERIC else [],
-        *[torch.bfloat16] if SM80OrLater and TEST_CUSPARSE_GENERIC else []))
+        *[torch.half] if SPARSE_FLOAT16_SUPPORTED else [],
+        *[torch.bfloat16] if SPARSE_BFLOAT16_SUPPORTED else []))
     @precisionOverride({torch.bfloat16: 3.5e-2, torch.float16: 1e-2})
     def test_sparse_addmm(self, device, dtype):
         def test_shape(m, n, p, nnz, broadcast, index_dtype, alpha_beta=None):
@@ -1984,18 +1986,15 @@ def test_shape(m, n, p, nnz, broadcast, index_dtype, alpha_beta=None):
             test_shape(7, 8, 9, 20, True, index_dtype, (1, 1))
 
     @skipCPUIfNoMklSparse
+    @skipIfRocmVersionLessThan((6, 3))
     @dtypes(*floating_and_complex_types())
     @precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 0.6,
                         torch.half: 1e-1, torch.cfloat: 1e-4, torch.cdouble: 1e-8})
     @dtypesIfCUDA(*floating_types_and(torch.complex64,
-                                      *[torch.bfloat16] if SM80OrLater else [],
-                                      *[torch.half] if SM53OrLater else [],
-                                      *[torch.complex128] if CUSPARSE_SPMM_COMPLEX128_SUPPORTED else []))
+                                      *[torch.bfloat16] if SPARSE_BFLOAT16_SUPPORTED else [],
+                                      *[torch.half] if SPARSE_FLOAT16_SUPPORTED else [],
+                                      *[torch.complex128] if SPARSE_COMPLEX128_SUPPORTED else []))
     @sparse_compressed_nonblock_layouts()
-    @skipCUDAIf(
-        not _check_cusparse_spgemm_available(),
-        "cuSparse Generic API SpGEMM is not available"
-    )
     def test_addmm_all_sparse_csr(self, device, dtype, layout):
         M = torch.randn(10, 25, device=device).to(dtype)
         m1 = torch.randn(10, 50, device=device).to(dtype)
@@ -2066,16 +2065,12 @@ def maybe_transpose(cond, m):
     @skipCPUIfNoMklSparse
     @dtypes(*floating_and_complex_types())
     @dtypesIfCUDA(*floating_types_and(torch.complex64,
-                                      *[torch.bfloat16] if SM80OrLater else [],
-                                      *[torch.half] if SM53OrLater else [],
-                                      *[torch.complex128]
-                                      if CUSPARSE_SPMM_COMPLEX128_SUPPORTED or HIPSPARSE_SPMM_COMPLEX128_SUPPORTED
-                                      else []))
+                                      *[torch.bfloat16] if SPARSE_BFLOAT16_SUPPORTED else [],
+                                      *[torch.half] if SPARSE_FLOAT16_SUPPORTED else [],
+                                      *[torch.complex128] if SPARSE_COMPLEX128_SUPPORTED else []))
     @precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 0.6,
                         torch.half: 1e-1, torch.cfloat: 1e-4, torch.cdouble: 1e-8})
     def test_addmm_sizes_all_sparse_csr(self, device, dtype, m, n, k):
-        if (TEST_WITH_ROCM and k != 0 and n != 0 and m != 0):
-            self.skipTest("Skipped on ROCm")
         M = torch.randn(n, m, device=device).to(dtype)
         m1 = torch.randn(n, k, device=device).to(dtype)
         m2 = torch.randn(k, m, device=device).to(dtype)
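
Aside on the skipIfRocmVersionLessThan((6, 3)) guards introduced above: the decorator takes the minimum ROCm version as a tuple and skips the test when the detected version is older. A simplified sketch of that behavior (the real helper lives in torch.testing._internal.common_utils and handles more edge cases):

```python
import unittest
import torch

# Simplified sketch of a skipIfRocmVersionLessThan-style decorator.
def skip_if_rocm_version_less_than(min_version):
    def decorator(fn):
        hip = torch.version.hip  # None on CUDA-only builds
        if hip is not None:
            release = tuple(int(x) for x in hip.split("-")[0].split("."))
            if release[:len(min_version)] < tuple(min_version):
                return unittest.skip(f"requires ROCm >= {min_version}")(fn)
        return fn
    return decorator
```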

torch/utils/hipify/cuda_to_hip_mappings.py (14 changes: 13 additions & 1 deletion)
@@ -8175,11 +8175,15 @@
     ("cusparseSpGEMMDescr_t", ("hipsparseSpGEMMDescr_t", CONV_TYPE, API_SPECIAL)),
     ("CUSPARSE_INDEX_32I", ("HIPSPARSE_INDEX_32I", CONV_NUMERIC_LITERAL, API_SPECIAL)),
     ("CUSPARSE_INDEX_64I", ("HIPSPARSE_INDEX_64I", CONV_NUMERIC_LITERAL, API_SPECIAL)),
-    ("CUSPARSE_ORDER_COL", ("HIPSPARSE_ORDER_COLUMN", CONV_NUMERIC_LITERAL, API_SPECIAL)),
+    ("CUSPARSE_ORDER_COL", ("HIPSPARSE_ORDER_COL", CONV_NUMERIC_LITERAL, API_SPECIAL)),
     ("CUSPARSE_ORDER_ROW", ("HIPSPARSE_ORDER_ROW", CONV_NUMERIC_LITERAL, API_SPECIAL)),
     ("CUSPARSE_MV_ALG_DEFAULT", ("HIPSPARSE_MV_ALG_DEFAULT", CONV_NUMERIC_LITERAL, API_SPECIAL)),
     ("CUSPARSE_MM_ALG_DEFAULT", ("HIPSPARSE_MM_ALG_DEFAULT", CONV_NUMERIC_LITERAL, API_SPECIAL)),
("CUSPARSE_SPMM_COO_ALG1", ("HIPSPARSE_SPMM_COO_ALG1", CONV_NUMERIC_LITERAL, API_SPECIAL)),
("CUSPARSE_SPMM_COO_ALG2", ("HIPSPARSE_SPMM_COO_ALG2", CONV_NUMERIC_LITERAL, API_SPECIAL)),
("CUSPARSE_SPMM_CSR_ALG1", ("HIPSPARSE_SPMM_CSR_ALG1", CONV_NUMERIC_LITERAL, API_SPECIAL)),
("CUSPARSE_SPMM_CSR_ALG2", ("HIPSPARSE_SPMM_CSR_ALG2", CONV_NUMERIC_LITERAL, API_SPECIAL)),
("CUSPARSE_SPMM_CSR_ALG3", ("HIPSPARSE_SPMM_CSR_ALG3", CONV_NUMERIC_LITERAL, API_SPECIAL)),
("CUSPARSE_COOMV_ALG", ("HIPSPARSE_COOMV_ALG", CONV_NUMERIC_LITERAL, API_SPECIAL)),
("CUSPARSE_SPMM_CSR_ALG1", ("HIPSPARSE_CSRMM_ALG1", CONV_NUMERIC_LITERAL, API_SPECIAL)),
("CUSPARSE_SPGEMM_DEFAULT", ("HIPSPARSE_SPGEMM_DEFAULT", CONV_NUMERIC_LITERAL, API_SPECIAL)),
@@ -8228,6 +8232,14 @@
         "CUSPARSE_STATUS_ZERO_PIVOT",
         ("HIPSPARSE_STATUS_ZERO_PIVOT", CONV_NUMERIC_LITERAL, API_SPECIAL),
     ),
+    (
+        "CUSPARSE_STATUS_NOT_SUPPORTED",
+        ("HIPSPARSE_STATUS_NOT_SUPPORTED", CONV_NUMERIC_LITERAL, API_SPECIAL),
+    ),
+    (
+        "CUSPARSE_STATUS_INSUFFICIENT_RESOURCES",
+        ("HIPSPARSE_STATUS_INSUFFICIENT_RESOURCES", CONV_NUMERIC_LITERAL, API_SPECIAL),
+    ),
     (
         "CUSPARSE_OPERATION_TRANSPOSE",
         ("HIPSPARSE_OPERATION_TRANSPOSE", CONV_NUMERIC_LITERAL, API_SPECIAL),
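
These mappings are what let the shared CUDA sources above build against hipSPARSE: hipify rewrites each cuSPARSE identifier to its HIP counterpart before compilation, so HIPSPARSE_ORDER_COL replaces the older HIPSPARSE_ORDER_COLUMN spelling and the two new status codes resolve after translation. A toy illustration of the substitution (the real pass in torch/utils/hipify is regex-driven and covers far more names):

```python
# Toy illustration of hipify's name mapping; the real implementation
# matches whole identifiers with regexes rather than plain str.replace.
MAPPINGS = {
    "CUSPARSE_ORDER_COL": "HIPSPARSE_ORDER_COL",
    "CUSPARSE_STATUS_NOT_SUPPORTED": "HIPSPARSE_STATUS_NOT_SUPPORTED",
    "CUSPARSE_STATUS_INSUFFICIENT_RESOURCES": "HIPSPARSE_STATUS_INSUFFICIENT_RESOURCES",
}

def hipify(source: str) -> str:
    for cuda_name, hip_name in MAPPINGS.items():
        source = source.replace(cuda_name, hip_name)
    return source

print(hipify("auto order = CUSPARSE_ORDER_COL;"))
# prints: auto order = HIPSPARSE_ORDER_COL;
```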