Motivation and description
While trying to call scatter!(+, dst, src, idxs) where src is a CUDA sparse matrix (a CUDA.CUSPARSE.CuSparseMatrixCSC, to be precise), I noticed that the GPU kernel isn't called. This is because CuSparseMatrixCSC is not a subtype of CUDA.AnyCuArray, so dispatch falls back to the CPU version. (I'm not sure whether I should report this on CUDA.jl or whether it's intended behaviour.)
Would it make sense to add a special dispatch for sparse arguments? It could simply call the existing kernels, so that such calls no longer end up in the CPU version. Perhaps it could even rely on cuDNN.
Alternatively, would it make sense to open an upstream issue so that CuSparseMatrixCSC <: AnyCuArray?
Possible Implementation
Add a new method:
function NNlib.scatter!(op::OP, dst::Union{CuSparseMatrixCSC, AnyCuArray}, src::...) where {OP}
    # Possibly call cuDNN or some other CUDA library that implements an optimized sparse version,
    # or simply call the existing `scatter_kernel!`
end
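To illustrate the dispatch behaviour described above without needing a GPU, here is a minimal sketch using stand-in types. Every name in it (ToyDenseGPUArray, ToyCuArray, ToySparseCSC, toy_scatter!) is hypothetical and only mimics the relationship between AnyCuArray, CuArray and CuSparseMatrixCSC; it is not the actual NNlib or CUDA.jl code. It only shows which method gets selected, and says nothing about whether the existing `scatter_kernel!` can index sparse storage correctly.

```julia
# Stand-in types; none of these are real CUDA.jl or NNlib definitions.
abstract type ToyDenseGPUArray end          # plays the role of CUDA.AnyCuArray
struct ToyCuArray <: ToyDenseGPUArray end   # plays the role of CuArray
struct ToySparseCSC end                     # plays the role of CuSparseMatrixCSC (NOT a subtype)

# Generic fallback, analogous to the CPU implementation.
toy_scatter!(op, dst, src, idx) = :cpu_fallback

# GPU method restricted to the dense GPU array type.
toy_scatter!(op, dst::ToyDenseGPUArray, src::ToyDenseGPUArray, idx) = :gpu_kernel

dense_arr = ToyCuArray()
sparse_arr = ToySparseCSC()

# Dense arguments hit the GPU method; a sparse src silently falls back.
before_dense = toy_scatter!(+, dense_arr, dense_arr, nothing)
before_sparse = toy_scatter!(+, dense_arr, sparse_arr, nothing)

# Proposed fix (sketch): an extra method that accepts either kind of GPU array.
const ToyAnyGPU = Union{ToySparseCSC, ToyDenseGPUArray}
toy_scatter!(op, dst::ToyAnyGPU, src::ToyAnyGPU, idx) = :gpu_kernel

after_sparse = toy_scatter!(+, dense_arr, sparse_arr, nothing)

println((before_dense, before_sparse, after_sparse))
```

In this toy setup the sparse call moves from the fallback to the GPU method once the Union-typed method exists, which is the effect the proposed method is meant to have. Making CuSparseMatrixCSC <: AnyCuArray upstream would achieve the same thing without the extra method.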