forked from pytorch/pytorch
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
This PR adds `addmv_out_sparse_csr_cuda`. The operation computes a matrix-vector multiplication with a sparse CSR matrix. Because `structured_delegate` is used, only the out variant needs to be implemented; the in-place and functional variants are autogenerated. Working on this PR revealed that float16 (and probably bfloat16) inputs are not handled correctly by cuSPARSE, so for those dtypes `addmm` is used instead: the vector and the result are unsqueezed to 2-D before the call and the result is squeezed back afterwards. ghstack-source-id: 05379cadceaccd015195d9e29fde0829ca84cbe1 Pull Request resolved: pytorch#61407
- Loading branch information
1 parent
87e5e84
commit 7e78952
Showing
8 changed files
with
206 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
#include <ATen/ATen.h> | ||
#include <ATen/Dispatch.h> | ||
#include <ATen/cuda/CUDASparse.h> | ||
#include <ATen/native/Resize.h> | ||
#include <ATen/native/sparse/cuda/SparseBlasImpl.h> | ||
|
||
#include <c10/util/MaybeOwned.h> | ||
|
||
namespace at { | ||
namespace native { | ||
|
||
// Out-variant of addmv for a sparse CSR `mat` on CUDA.
//
// Prepares `result` from `self` (broadcast to {mat.size(0)}) and then delegates
// the actual computation, parameterized by `beta` and `alpha`, to the cuSPARSE
// backed implementation in sparse::cuda::impl.
//
// Arguments:
//   self   - tensor that is broadcast to one element per row of `mat`
//   mat    - 2-D sparse CSR matrix
//   vec    - 1-D dense vector with mat.size(1) elements
//   beta   - scale applied to `self`; when it is zero the values of `self`
//            are ignored so that NaN/Inf in `self` do not propagate
//   alpha  - scale forwarded to the implementation for the mat/vec product
//   result - output tensor; resized to the broadcast shape unless it is `self`
//
// Returns `result`. Raises (via TORCH_CHECK) on wrong dimensionality, on a
// mat/vec size mismatch, or when the tensors are not on the same GPU.
Tensor& addmv_out_sparse_csr_cuda(const Tensor &self, const Tensor &mat, const Tensor &vec, const Scalar& beta, const Scalar& alpha, Tensor& result) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(mat.is_sparse_csr());

  TORCH_CHECK(mat.dim() == 2, "addmv: Expected mat to be 2-D");
  TORCH_CHECK(vec.dim() == 1, "addmv: Expected vec to be 1-D");
  // Validate up front so that the empty-matrix shortcut below cannot silently
  // accept operands whose shapes do not match.
  TORCH_CHECK(
      mat.size(1) == vec.size(0),
      "addmv: size mismatch, mat has shape [", mat.size(0), ", ", mat.size(1),
      "] but vec has ", vec.size(0), " elements");

  TensorArg args[]{{result, "out", 0}, {self, "self", 1}, {mat, "mat", 2}, {vec, "vec", 3}};
  checkAllSameGPU(__func__, args);

  // `self` broadcast to one element per row of `mat`.
  c10::MaybeOwned<Tensor> self_ = expand_size(self, {mat.size(0)});
  auto betaval = beta.toComplexDouble();

  // is_same() compares the underlying tensor rather than the address of the
  // C++ handle, so two handles to the same tensor are treated as in-place too.
  if (!result.is_same(self)) {
    at::native::resize_output(result, self_->sizes());
    if (betaval != 0.0) {
      at::native::copy_(result, *self_);
    }
  }

  if (mat._nnz() == 0) {
    // shortcut for an empty matrix
    // By definition, when beta==0, values in self should be ignored. nans and infs
    // should not propagate
    if (betaval == 0.0) {
      return result.zero_();
    }
    // Multiply the *broadcast* self: using the un-expanded `self` here would let
    // mul_out compute a shape different from the one `result` was resized to.
    return at::mul_out(
        result,
        *self_,
        at::native::scalar_tensor(
            beta, self.scalar_type(), c10::nullopt /* layout */, at::kCPU, c10::nullopt /* pin_memory */));
  }

  // cuda 11.3 version computes garbage for float16 inputs
  // couldn't check bfloat16 because it requires Ampere GPU but I assume the problem is same
  // addmm works fine
  if (vec.scalar_type() == kHalf || vec.scalar_type() == kBFloat16) {
    // Work on a 2-D view that shares storage with `result` instead of mutating
    // the metadata of `result` in place; if the implementation throws, the
    // caller's `out` tensor keeps its 1-D shape.
    Tensor result_2d = result.unsqueeze(-1);
    sparse::cuda::impl::addmm_out_sparse_csr_dense_cuda_impl(mat, vec.unsqueeze(-1), beta, alpha, result_2d);
    return result;
  }

  sparse::cuda::impl::addmv_out_sparse_csr_cuda_impl(mat, vec, beta, alpha, result);
  return result;
}
|
||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters