diff --git a/ext/NNlibCUDAExt/scatter.jl b/ext/NNlibCUDAExt/scatter.jl index 874207c7..9b323d50 100644 --- a/ext/NNlibCUDAExt/scatter.jl +++ b/ext/NNlibCUDAExt/scatter.jl @@ -1,4 +1,5 @@ # supported op: +, -, *, /, max, min, &, |, mean +import CUDA.CUSPARSE: AbstractCuSparseArray function scatter_kernel!(op::OP, dst, src, idx) where OP index = threadIdx().x + (blockIdx().x - 1) * blockDim().x @@ -43,7 +44,10 @@ function scatter_kernel!(op::OP, dst, src, idx::CUDA.CuDeviceArray{<:CartesianIn return nothing end -function NNlib.scatter!(op::OP, dst::AnyCuArray, src::AnyCuArray, idx::AnyCuArray) where OP + +function NNlib.scatter!(op::OP, dst::Union{AnyCuArray,AbstractCuSparseArray}, + src::Union{AnyCuArray,AbstractCuSparseArray}, + idx::Union{AnyCuArray,AbstractCuSparseArray}) where OP dims = NNlib.scatter_dims(dst, src, idx) args = if dims == 0 max_idx = length(idx) @@ -63,7 +67,9 @@ function NNlib.scatter!(op::OP, dst::AnyCuArray, src::AnyCuArray, idx::AnyCuArra return dst end -function NNlib.scatter!(op::typeof(mean), dst::AnyCuArray, src::AnyCuArray, idx::AnyCuArray) +function NNlib.scatter!(op::typeof(mean), dst::Union{AnyCuArray,AbstractCuSparseArray}, + src::Union{AnyCuArray,AbstractCuSparseArray}, + idx::Union{AnyCuArray,AbstractCuSparseArray}) Ns = NNlib.scatter!(+, zero(dst), one.(src), idx) dst_ = NNlib.scatter!(+, zero(dst), src, idx) dst .+= NNlib.safe_div.(dst_, Ns) @@ -74,7 +80,7 @@ end ## Gradients function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx, - rev_idx, max_idx, T::Type{TT}) where {OP,TT} + rev_idx, max_idx, T::Type{TT}) where {OP,TT} index = threadIdx().x + (blockIdx().x - 1) * blockDim().x @inbounds if index <= max_idx @@ -94,7 +100,7 @@ function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx, end function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx::CUDA.CuDeviceArray{<:CartesianIndex}, - rev_idx, max_idx, T::Type{TT}) where {OP,TT} + rev_idx, max_idx, T::Type{TT}) where {OP,TT} index = threadIdx().x + (blockIdx().x - 1) * blockDim().x @inbounds if index <= max_idx @@ -114,7 +120,7 @@ function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx::CUDA.CuDeviceArray{<:Ca end function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx, - rev_idx, pre_cart_idx, max_dims_idx, max_idx, T::Type{TT}) where {OP,TT} + rev_idx, pre_cart_idx, max_dims_idx, max_idx, T::Type{TT}) where {OP,TT} index = threadIdx().x + (blockIdx().x - 1) * blockDim().x @inbounds if index <= max_idx @@ -160,8 +166,8 @@ function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx::CUDA.CuDeviceArray{<:Ca end function NNlib.∇scatter_src(op::Union{typeof(*),typeof(/)}, Δ, dst, - src::AnyCuArray{Tsrc,Nsrc}, - idx::AnyCuArray{Tidx,Nidx}) where {Tsrc,Tidx,Nsrc,Nidx} + src::Union{AnyCuArray{Tsrc,Nsrc},AbstractCuSparseArray}, + idx::Union{AnyCuArray{Tidx,Nidx},AbstractCuSparseArray}) where {Tsrc,Tidx,Nsrc,Nidx} dims = Nsrc - Nidx Δsrc = NNlib.modify_src(op, NNlib.gather(Δ, idx), src) rev_idx = NNlib.reverse_indices(idx)