Sparse CSR CUDA: add addmv_out
This PR adds `addmv_out_sparse_csr_cuda`, which computes the matrix-vector
product for sparse CSR matrices on CUDA. Since `structured_delegate` is used,
only the out variant needs to be implemented; the in-place and functional
variants are autogenerated.

Working on this PR revealed that float16 (and probably bfloat16) inputs are
not handled correctly by cuSPARSE, so for these dtypes `addmm` is used
instead, with the vector unsqueezed to a single column and the result
squeezed back.

ghstack-source-id: 32ca31b4fe27ba827baca43db1456edcbb59436b
Pull Request resolved: pytorch#61407
IvanYashchuk committed Jul 12, 2021
1 parent 732e011 commit 1b2e495
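For context, a minimal usage sketch of what the new kernel enables (not part of the diff; sizes and values are illustrative, and a CUDA build with cuSPARSE generic API support is assumed):

import torch

# A small 2x3 sparse CSR matrix on the GPU.
crow_indices = torch.tensor([0, 2, 3])
col_indices = torch.tensor([0, 2, 1])
values = torch.tensor([1.0, 2.0, 3.0])
csr = torch.sparse_csr_tensor(crow_indices, col_indices, values,
                              size=(2, 3), device="cuda")

vec = torch.randn(3, device="cuda")
y = torch.randn(2, device="cuda")

# Both of these now dispatch to addmv_out_sparse_csr_cuda on SparseCsrCUDA.
out_mv = csr.matmul(vec)                                    # A @ x
out_addmv = torch.addmv(y, csr, vec, beta=0.5, alpha=2.0)   # 0.5*y + 2.0*(A @ x)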
Showing 8 changed files with 206 additions and 7 deletions.
1 change: 1 addition & 0 deletions BUILD.bazel
@@ -367,6 +367,7 @@ filegroup(
"aten/src/ATen/native/miopen/Conv_miopen.cpp",
"aten/src/ATen/native/miopen/RNN_miopen.cpp",
"aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cpp",
"aten/src/ATen/native/sparse/cuda/SparseBlas.cpp",
"aten/src/ATen/native/sparse/cuda/SparseBlasImpl.cpp",
"aten/src/THC/THCCachingHostAllocator.cpp",
"aten/src/THC/THCGeneral.cpp",
24 changes: 24 additions & 0 deletions aten/src/ATen/cuda/CUDASparseDescriptors.h
@@ -123,6 +123,30 @@ class TORCH_CUDA_CPP_API CuSparseDnMatDescriptor
}
};

class TORCH_CUDA_CPP_API CuSparseDnVecDescriptor
: public CuSparseDescriptor<cusparseDnVecDescr, &cusparseDestroyDnVec> {
public:
CuSparseDnVecDescriptor(const Tensor& input) {
// cuSPARSE doesn't support batched vectors
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.dim() == 1);

// cuSPARSE doesn't support non-contiguous vectors
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.is_contiguous());
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.is_non_overlapping_and_dense());

cudaDataType value_type = ScalarTypeToCudaDataType(input.scalar_type());
check_supported_cuda_type(value_type);

cusparseDnVecDescr_t raw_descriptor;
TORCH_CUDASPARSE_CHECK(cusparseCreateDnVec(
&raw_descriptor,
input.numel(),
input.data_ptr(),
value_type));
descriptor_.reset(raw_descriptor);
}
};

class TORCH_CUDA_CPP_API CuSparseSpMatDescriptor
: public CuSparseDescriptor<cusparseSpMatDescr, &cusparseDestroySpMat> {};

6 changes: 3 additions & 3 deletions aten/src/ATen/native/Blas.cpp
@@ -13,7 +13,7 @@ TORCH_META_FUNC(addmv)(const Tensor &self, const Tensor &mat, const Tensor &vec,
TORCH_CHECK(mat.size(1) == vec.size(0) && (mat.size(0) == self.numel() || self.numel() == 1),
"size mismatch, got ", self.size(0), ", ", mat.size(0), "x", mat.size(1), ",", vec.size(0));
auto names = at::namedinference::propagate_names_for_addmv(mat, vec, self);
set_output(0, IntArrayRef(mat.sizes().data(), 1), {}, mat.options(), names);
set_output(0, IntArrayRef(mat.sizes().data(), 1), {}, vec.options(), names);
auto result = maybe_get_output(0);
//this check can fire for inplace op only, for all other versions result is guaranteed to be correct size
TORCH_CHECK(result.dim() == 1 && result.sizes()[0] == mat.sizes()[0], "output of addmv operation should be 1D with ",
@@ -91,14 +91,14 @@ Tensor &mv_out(const Tensor &self, const Tensor &vec, Tensor& result) {
//in addmv, because addmv expects self to satisfy proper conditions
//to avoid this, supply correctly sized self, its contents doesn't matter because beta is 0
if (result.dim() > 1 || (result.numel() != self.size(0) || result.numel() !=1)) {
Tensor self_addmv = at::empty({self.size(0)}, self.options());
Tensor self_addmv = at::empty({self.size(0)}, vec.options());
return at::addmv_out(result, self_addmv, self, vec, 0, 1);
}
return at::addmv_out(result, result, self, vec, 0, 1);
}

Tensor mv(const Tensor &self, const Tensor &vec) {
Tensor result = at::empty({self.size(0)}, self.options());
Tensor result = at::empty({self.size(0)}, vec.options());
//inplace version is more efficient if we can use it
return at::addmv_(result, self, vec, 0, 1);
}
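The switch to vec.options() in this file matters for the sparse path: mv and addmv produce a dense result, so its dtype, device, and (strided) layout should follow the dense vector rather than the possibly sparse CSR matrix. A hedged sketch of the resulting behavior, with illustrative tensors:

import torch

crow = torch.tensor([0, 2, 3])
col = torch.tensor([0, 2, 1])
val = torch.tensor([1.0, 2.0, 3.0])
csr = torch.sparse_csr_tensor(crow, col, val, size=(2, 3), device="cuda")
vec = torch.randn(3, device="cuda")

out = torch.mv(csr, vec)
assert out.layout == torch.strided      # dense output, allocated with vec.options()
assert out.dtype == vec.dtype and out.device == vec.device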
5 changes: 3 additions & 2 deletions aten/src/ATen/native/native_functions.yaml
@@ -441,6 +441,7 @@
dispatch:
CPU: addmv_out_cpu
CUDA: addmv_out_cuda
SparseCsrCUDA: addmv_out_sparse_csr_cuda

- func: addr(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
variants: function, method
@@ -3050,8 +3051,8 @@
- func: mv(Tensor self, Tensor vec) -> Tensor
variants: function, method
dispatch:
CPU, CUDA: mv
SparseCPU, SparseCUDA, SparseCsrCPU, SparseCsrCUDA: mv_sparse
CPU, CUDA, SparseCsrCUDA: mv
SparseCPU, SparseCUDA, SparseCsrCPU: mv_sparse

- func: mv.out(Tensor self, Tensor vec, *, Tensor(a!) out) -> Tensor(a!)
dispatch:
61 changes: 61 additions & 0 deletions aten/src/ATen/native/sparse/cuda/SparseBlas.cpp
@@ -0,0 +1,61 @@
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <ATen/cuda/CUDASparse.h>
#include <ATen/native/Resize.h>
#include <ATen/native/sparse/cuda/SparseBlasImpl.h>

#include <c10/util/MaybeOwned.h>

namespace at {
namespace native {

Tensor& addmv_out_sparse_csr_cuda(const Tensor &self, const Tensor &mat, const Tensor &vec, const Scalar& beta, const Scalar& alpha, Tensor& result) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(mat.is_sparse_csr());

TORCH_CHECK(mat.dim() == 2, "addmv: Expected mat to be 2-D");
TORCH_CHECK(vec.dim() == 1, "addmv: Expected vec to be 1-D");

TensorArg args[]{{result, "out", 0}, {self, "self", 1}, {mat, "mat", 2}, {vec, "vec", 3}};
checkAllSameGPU(__func__, args);

c10::MaybeOwned<Tensor> self_ = expand_size(self, {mat.size(0)});
auto betaval = beta.toComplexDouble();

if (&result != &self) {
at::native::resize_output(result, self_->sizes());
if (betaval != 0.0) {
at::native::copy_(result, *self_);
}
}

if (mat._nnz() == 0) {
// shortcut for an empty matrix
// By definition, when beta==0, values in self should be ignored. nans and infs
// should not propagate
if (betaval == 0.0) {
return result.zero_();
} else {
return at::mul_out(
const_cast<Tensor&>(result),
self,
at::native::scalar_tensor(
beta, self.scalar_type(), c10::nullopt /* layout */, at::kCPU, c10::nullopt /* pin_memory */));
}
}

// CUDA 11.3 cuSPARSE computes garbage for float16 inputs here.
// bfloat16 could not be checked (it requires an Ampere GPU), but the problem is assumed to be the same.
// addmm works fine for these dtypes, so route the computation through it instead.
if (vec.scalar_type() == kHalf || vec.scalar_type() == kBFloat16) {
result.unsqueeze_(-1);
sparse::cuda::impl::addmm_out_sparse_csr_dense_cuda_impl(mat, vec.unsqueeze(-1), beta, alpha, result);
result.squeeze_(-1);
return result;
}

sparse::cuda::impl::addmv_out_sparse_csr_cuda_impl(mat, vec, beta, alpha, result);
return result;
}

} // namespace native
} // namespace at
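The float16/bfloat16 fallback above reformulates the matrix-vector product as a matrix-matrix product. A hedged Python sketch of the same equivalence (not part of the diff; values are illustrative, and float16 support assumes a sufficiently recent GPU: the tests gate float16 on SM53+ and bfloat16 on SM80+):

import torch

crow = torch.tensor([0, 1, 2])
col = torch.tensor([1, 0])
val = torch.tensor([1.0, 2.0], dtype=torch.float16)
A = torch.sparse_csr_tensor(crow, col, val, size=(2, 2), device="cuda")
x = torch.randn(2, dtype=torch.float16, device="cuda")
y = torch.randn(2, dtype=torch.float16, device="cuda")

# addmv treats the vector as an n x 1 dense matrix and goes through addmm,
# then squeezes the trailing dimension back off the result.
via_addmm = torch.addmm(y.unsqueeze(-1), A, x.unsqueeze(-1),
                        beta=0.5, alpha=2.0).squeeze(-1)
via_addmv = torch.addmv(y, A, x, beta=0.5, alpha=2.0)
assert torch.allclose(via_addmv.float(), via_addmm.float(), atol=1e-2, rtol=1e-2)

When the sparse matrix has no non-zeros, the kernel instead takes the shortcut shown above: result.zero_() for beta == 0 (so NaNs and Infs in self do not propagate), otherwise beta * self.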
96 changes: 96 additions & 0 deletions aten/src/ATen/native/sparse/cuda/SparseBlasImpl.cpp
@@ -29,6 +29,16 @@ c10::MaybeOwned<Tensor> inline prepare_dense_matrix_for_cusparse(
}
}

c10::MaybeOwned<Tensor> inline prepare_dense_vector_for_cusparse(
const Tensor& tensor) {
if (tensor.is_non_overlapping_and_dense()) {
return c10::MaybeOwned<Tensor>::borrowed(tensor);
} else {
return c10::MaybeOwned<Tensor>::owned(
tensor.clone(at::MemoryFormat::Contiguous));
}
}

} // anonymous namespace

void addmm_out_sparse_csr_dense_cuda_impl(
@@ -129,6 +139,92 @@ void addmm_out_sparse_csr_dense_cuda_impl(
#endif
}

/*
Computes a sparse matrix-dense vector product defined as
y <- alpha*op(A)*x + beta*y
Args:
* `mat` - Tensor storing sparse m x n matrix A.
* `vec` - Tensor storing dense vector x of size n.
* `result` - [in] Tensor storing dense vector y of size m.
[out] result of the operation.
*/
void addmv_out_sparse_csr_cuda_impl(
const at::sparse_csr::SparseCsrTensor& mat,
const Tensor& vec,
const Scalar& beta,
const Scalar& alpha,
const Tensor& result) {
#if !AT_USE_CUSPARSE_GENERIC_API()
TORCH_CHECK(
false,
"Calling addmv on a sparse GPU tensor requires compiling ",
"PyTorch with CUDA 10.2+ (CUDA 11+ on Windows). ",
"Please use PyTorch built with newer CUDA version.");
#else
cusparseOperation_t opA = CUSPARSE_OPERATION_NON_TRANSPOSE;

c10::MaybeOwned<Tensor> result_ = prepare_dense_vector_for_cusparse(result);
c10::MaybeOwned<Tensor> vec_ = prepare_dense_vector_for_cusparse(vec);

// TODO: update this to support COO sparse layout
auto descA = at::cuda::sparse::CuSparseSpMatCsrDescriptor(mat);
auto descX = at::cuda::sparse::CuSparseDnVecDescriptor(*vec_);
auto descY = at::cuda::sparse::CuSparseDnVecDescriptor(*result_);

// There is no dispatch for kHalf and kBFloat16 here because cuSPARSE
// computes garbage for them (latest checked version: CUDA 11.3).
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
result.scalar_type(),
"addmv_out_sparse_csr_cuda_impl",
[&] {
auto beta_ = beta.to<scalar_t>();
auto alpha_ = alpha.to<scalar_t>();
cudaDataType compute_type = at::cuda::getCudaDataType<scalar_t>();
auto handle = at::cuda::getCurrentCUDASparseHandle();

// cusparseSpMVAlg_t was updated in cuda 11.2.1
#if CUSPARSE_VERSION >= 11400
cusparseSpMVAlg_t alg = CUSPARSE_SPMV_ALG_DEFAULT;
#else
cusparseSpMVAlg_t alg = CUSPARSE_MV_ALG_DEFAULT;
#endif

size_t buffer_size;
TORCH_CUDASPARSE_CHECK(cusparseSpMV_bufferSize(
handle,
opA,
&alpha_,
descA.descriptor(),
descX.descriptor(),
&beta_,
descY.descriptor(),
compute_type,
alg,
&buffer_size // output
));

auto& allocator = *c10::cuda::CUDACachingAllocator::get();
auto work_data = allocator.allocate(buffer_size);

TORCH_CUDASPARSE_CHECK(cusparseSpMV(
handle,
opA,
&alpha_,
descA.descriptor(),
descX.descriptor(),
&beta_,
descY.descriptor(),
compute_type,
alg,
work_data.get()));
});
if (!result.is_same(*result_)) {
result.copy_(*result_);
}
#endif
}

} // namespace impl
} // namespace cuda
} // namespace sparse
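The doc comment above defines the operation as y <- alpha*op(A)*x + beta*y, and prepare_dense_vector_for_cusparse copies non-contiguous inputs to contiguous memory before handing them to cuSPARSE. A hedged sketch checking both against a dense reference (not part of the diff; sizes, values, and tolerances are illustrative):

import torch

crow = torch.tensor([0, 2, 3])
col = torch.tensor([0, 2, 1])
val = torch.tensor([1.0, 2.0, 3.0])
A = torch.sparse_csr_tensor(crow, col, val, size=(2, 3), device="cuda")
A_dense = torch.tensor([[1.0, 0.0, 2.0],
                        [0.0, 3.0, 0.0]], device="cuda")

x = torch.randn(6, device="cuda")[::2]   # non-contiguous vector of length 3
y = torch.randn(2, device="cuda")
alpha, beta = 2.0, 0.5

out = torch.addmv(y, A, x, beta=beta, alpha=alpha)
ref = beta * y + alpha * (A_dense @ x)
assert torch.allclose(out, ref, atol=1e-5)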
7 changes: 7 additions & 0 deletions aten/src/ATen/native/sparse/cuda/SparseBlasImpl.h
@@ -17,6 +17,13 @@ void addmm_out_sparse_csr_dense_cuda_impl(
const Scalar& alpha,
const Tensor& result);

void addmv_out_sparse_csr_cuda_impl(
const at::sparse_csr::SparseCsrTensor& mat,
const Tensor& vec,
const Scalar& beta,
const Scalar& alpha,
const Tensor& result);

} // namespace impl
} // namespace cuda
} // namespace sparse
13 changes: 11 additions & 2 deletions test/test_sparse_csr.py
@@ -388,7 +388,11 @@ def test_matmul_device_mismatch(self, device, dtype):
with self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device"):
torch.addmm(s, csr, m2)

@dtypes(torch.float, torch.double)
@skipCUDAIfNoCusparseGeneric
@dtypes(*torch.testing.floating_types())
@dtypesIfCUDA(*get_all_complex_dtypes(),
*get_all_fp_dtypes(include_half=SM53OrLater, include_bfloat16=SM80OrLater))
@precisionOverride({torch.bfloat16: 1e-2, torch.float16: 1e-2})
def test_csr_matvec(self, device, dtype):
side = 100
for index_dtype in [torch.int32, torch.int64]:
@@ -401,7 +405,12 @@ def test_csr_matvec(self, device, dtype):
self.assertEqual(res, expected)

bad_vec = torch.randn(side + 10, dtype=dtype, device=device)
with self.assertRaisesRegex(RuntimeError, "mv: expected"):
err_msg = "mv: expected"
# CUDA path now uses generic meta/structured implementation
# TODO: move CPU path to not use `mv_sparse` function
if self.device_type == 'cuda':
err_msg = "size mismatch, got"
with self.assertRaisesRegex(RuntimeError, err_msg):
csr.matmul(bad_vec)

@dtypes(torch.double)
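Based on the updated expectation in the test above, a hedged sketch of the size-mismatch behavior on the two device paths (the messages are the ones asserted by the test, not guaranteed verbatim elsewhere; a CUDA-enabled build is assumed):

import torch

crow = torch.tensor([0, 1, 2])
col = torch.tensor([0, 1])
val = torch.tensor([1.0, 2.0])

for device, msg in [("cpu", "mv: expected"), ("cuda", "size mismatch, got")]:
    A = torch.sparse_csr_tensor(crow, col, val, size=(2, 2), device=device)
    bad_vec = torch.randn(12, device=device)   # expected length 2
    try:
        A.matmul(bad_vec)
    except RuntimeError as err:
        # CPU still dispatches to mv_sparse; CUDA now goes through the structured addmv meta check.
        assert msg in str(err)
    else:
        raise AssertionError("expected a RuntimeError for the size mismatch")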
