forked from pytorch/pytorch
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
This PR adds `addmv_out_sparse_csr_cuda`. The operation is used to compute matrix-vector multiplication. Since structured_delegate is used we only need to implement the out variant, the in-place and normal variants are autogenerated. Working on this PR revealed that float16 (and probably bfloat16) inputs do not work correctly in cusparse, therefore for this case `addmm` is used with squeezes and unsqueezes. ghstack-source-id: 31acf911b4d19503647769d5a4f512c4679a0a8e Pull Request resolved: pytorch#61407
- Loading branch information
1 parent
3b25d9b
commit 4315f17
Showing
8 changed files
with
206 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
#include <ATen/ATen.h> | ||
#include <ATen/Dispatch.h> | ||
#include <ATen/cuda/CUDASparse.h> | ||
#include <ATen/native/Resize.h> | ||
#include <ATen/native/sparse/cuda/SparseBlasImpl.h> | ||
|
||
#include <c10/util/MaybeOwned.h> | ||
|
||
namespace at { | ||
namespace native { | ||
|
||
Tensor& addmv_out_sparse_csr_cuda(const Tensor &self, const Tensor &mat, const Tensor &vec, const Scalar& beta, const Scalar& alpha, Tensor& result) { | ||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(mat.is_sparse_csr()); | ||
|
||
TORCH_CHECK(mat.dim() == 2, "addmv: Expected mat to be 2-D"); | ||
TORCH_CHECK(vec.dim() == 1, "addmv: Expected vec to be 1-D"); | ||
|
||
TensorArg args[]{{result, "out", 0}, {self, "self", 1}, {mat, "mat", 2}, {vec, "vec", 3}}; | ||
checkAllSameGPU(__func__, args); | ||
|
||
c10::MaybeOwned<Tensor> self_ = expand_size(self, {mat.size(0)}); | ||
auto betaval = beta.toComplexDouble(); | ||
|
||
if (&result != &self) { | ||
at::native::resize_output(result, self_->sizes()); | ||
if (betaval != 0.0) { | ||
at::native::copy_(result, *self_); | ||
} | ||
} | ||
|
||
if (mat._nnz() == 0) { | ||
// shortcut for an empty matrix | ||
// By definition, when beta==0, values in self should be ignored. nans and infs | ||
// should not propagate | ||
if (betaval == 0.0) { | ||
return result.zero_(); | ||
} else { | ||
return at::mul_out( | ||
const_cast<Tensor&>(result), | ||
self, | ||
at::native::scalar_tensor( | ||
beta, self.scalar_type(), c10::nullopt /* layout */, at::kCPU, c10::nullopt /* pin_memory */)); | ||
} | ||
} | ||
|
||
// cuda 11.3 version computes garbage for float16 inputs | ||
// couldn't check bfloat16 because it requires Ampere GPU but I assume the problem is same | ||
// addmm works fine | ||
if (vec.scalar_type() == kHalf || vec.scalar_type() == kBFloat16) { | ||
result.unsqueeze_(-1); | ||
sparse::impl::cuda::addmm_out_sparse_csr(mat, vec.unsqueeze(-1), beta, alpha, result); | ||
result.squeeze_(-1); | ||
return result; | ||
} | ||
|
||
sparse::impl::cuda::addmv_out_sparse_csr(mat, vec, beta, alpha, result); | ||
return result; | ||
} | ||
|
||
} // namespace native | ||
} // namespace at |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters