From 472122ece30cb7388b0a35ab82daa75ba9ee4f14 Mon Sep 17 00:00:00 2001 From: Alonso Martinez Cisneros Date: Fri, 14 Nov 2025 15:13:33 +0100 Subject: [PATCH 1/4] Expanded dispatch of scatter! to include AbstractCuSparseArray --- ext/NNlibCUDAExt/scatter.jl | 31 ++++++++++++++++++------------- 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/ext/NNlibCUDAExt/scatter.jl b/ext/NNlibCUDAExt/scatter.jl index 874207c77..15b7c1a08 100644 --- a/ext/NNlibCUDAExt/scatter.jl +++ b/ext/NNlibCUDAExt/scatter.jl @@ -1,4 +1,5 @@ # supported op: +, -, *, /, max, min, &, |, mean +import CUDA: CUDA.CUSPARSE.AbstractcuSparseArray function scatter_kernel!(op::OP, dst, src, idx) where OP index = threadIdx().x + (blockIdx().x - 1) * blockDim().x @@ -23,7 +24,7 @@ function scatter_kernel!(op::OP, dst, src, idx, max_idx, max_dims_idx, dims_size index = threadIdx().x + (blockIdx().x - 1) * blockDim().x @inbounds if index <= max_idx - j, k = divrem(index-1, max_dims_idx) + j, k = divrem(index - 1, max_dims_idx) dims_i = CartesianIndices(dims_size)[k+1] CUDA.@atomic dst[Tuple(dims_i)..., idx[j+1]...] = op(dst[Tuple(dims_i)..., idx[j+1]...], src[index]) end @@ -31,11 +32,11 @@ function scatter_kernel!(op::OP, dst, src, idx, max_idx, max_dims_idx, dims_size end function scatter_kernel!(op::OP, dst, src, idx::CUDA.CuDeviceArray{<:CartesianIndex}, - max_idx, max_dims_idx, dims_size) where OP + max_idx, max_dims_idx, dims_size) where OP index = threadIdx().x + (blockIdx().x - 1) * blockDim().x @inbounds if index <= max_idx - j, k = divrem(index-1, max_dims_idx) + j, k = divrem(index - 1, max_dims_idx) dims_i = CartesianIndices(dims_size)[k+1] li = Base._to_linear_index(dst, Tuple(dims_i)..., Tuple(idx[j+1])...) CUDA.@atomic dst[li] = op(dst[li], src[index]) @@ -43,7 +44,10 @@ function scatter_kernel!(op::OP, dst, src, idx::CUDA.CuDeviceArray{<:CartesianIn return nothing end -function NNlib.scatter!(op::OP, dst::AnyCuArray, src::AnyCuArray, idx::AnyCuArray) where OP + +function NNlib.scatter!(op::OP, dst::Union{AnyCuArray,AbstractCuSparseArray}, + src::Union{SpAnyCuArray,AbstractCuSparseArray}, + idx::Union{SpAnyCuArray,AbstractCuSparseArray}) where OP dims = NNlib.scatter_dims(dst, src, idx) args = if dims == 0 max_idx = length(idx) @@ -55,7 +59,7 @@ function NNlib.scatter!(op::OP, dst::AnyCuArray, src::AnyCuArray, idx::AnyCuArra op, dst, src, idx, max_idx, max_dims_idx, dims_size end - kernel = @cuda launch=false scatter_kernel!(args...) + kernel = @cuda launch = false scatter_kernel!(args...) config = launch_configuration(kernel.fun; max_threads=256) threads = min(max_idx, config.threads) blocks = cld(max_idx, threads) @@ -63,7 +67,8 @@ function NNlib.scatter!(op::OP, dst::AnyCuArray, src::AnyCuArray, idx::AnyCuArra return dst end -function NNlib.scatter!(op::typeof(mean), dst::AnyCuArray, src::AnyCuArray, idx::AnyCuArray) +function NNlib.scatter!(op::typeof(mean), dst::Union{AnyCuArray,AbstractCuSparseArray}, + src::Union{AnyCuArray,AbstractCuArray}, idx::Union{AnyCuArray,AbstractCuArray}) Ns = NNlib.scatter!(+, zero(dst), one.(src), idx) dst_ = NNlib.scatter!(+, zero(dst), src, idx) dst .+= NNlib.safe_div.(dst_, Ns) @@ -74,7 +79,7 @@ end ## Gradients function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx, - rev_idx, max_idx, T::Type{TT}) where {OP,TT} + rev_idx, max_idx, T::Type{TT}) where {OP,TT} index = threadIdx().x + (blockIdx().x - 1) * blockDim().x @inbounds if index <= max_idx @@ -94,7 +99,7 @@ function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx, end function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx::CUDA.CuDeviceArray{<:CartesianIndex}, - rev_idx, max_idx, T::Type{TT}) where {OP,TT} + rev_idx, max_idx, T::Type{TT}) where {OP,TT} index = threadIdx().x + (blockIdx().x - 1) * blockDim().x @inbounds if index <= max_idx @@ -114,7 +119,7 @@ function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx::CUDA.CuDeviceArray{<:Ca end function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx, - rev_idx, pre_cart_idx, max_dims_idx, max_idx, T::Type{TT}) where {OP,TT} + rev_idx, pre_cart_idx, max_dims_idx, max_idx, T::Type{TT}) where {OP,TT} index = threadIdx().x + (blockIdx().x - 1) * blockDim().x @inbounds if index <= max_idx @@ -137,7 +142,7 @@ function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx, end function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx::CUDA.CuDeviceArray{<:CartesianIndex}, - rev_idx, pre_cart_idx, max_dims_idx, max_idx, T::Type{TT}) where {OP,TT} + rev_idx, pre_cart_idx, max_dims_idx, max_idx, T::Type{TT}) where {OP,TT} index = threadIdx().x + (blockIdx().x - 1) * blockDim().x @inbounds if index <= max_idx @@ -160,8 +165,8 @@ function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx::CUDA.CuDeviceArray{<:Ca end function NNlib.∇scatter_src(op::Union{typeof(*),typeof(/)}, Δ, dst, - src::AnyCuArray{Tsrc,Nsrc}, - idx::AnyCuArray{Tidx,Nidx}) where {Tsrc,Tidx,Nsrc,Nidx} + src::Union{AnyCuArray{Tsrc,Nsrc},AbstractCuSparseArray}, + idx::Union{AnyCuArray{Tidx,Nidx},AbstractCuSparseArray}) where {Tsrc,Tidx,Nsrc,Nidx} dims = Nsrc - Nidx Δsrc = NNlib.modify_src(op, NNlib.gather(Δ, idx), src) rev_idx = NNlib.reverse_indices(idx) @@ -177,7 +182,7 @@ function NNlib.∇scatter_src(op::Union{typeof(*),typeof(/)}, Δ, dst, args = op, Δsrc, src, idx, rev_idx, pre_cart_idx, max_dims_idx, max_idx, Tsrc end - kernel = @cuda launch=false ∇scatter_src_kernel!(args...) + kernel = @cuda launch = false ∇scatter_src_kernel!(args...) config = launch_configuration(kernel.fun; max_threads=256) threads = min(max_idx, config.threads) blocks = cld(max_idx, threads) From cfccdb3037ba46bf29dd28660a6764bd0341721d Mon Sep 17 00:00:00 2001 From: Alonso Martinez Cisneros Date: Fri, 14 Nov 2025 15:28:17 +0100 Subject: [PATCH 2/4] Fix incorrect import --- ext/NNlibCUDAExt/scatter.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/NNlibCUDAExt/scatter.jl b/ext/NNlibCUDAExt/scatter.jl index 15b7c1a08..29e135a8d 100644 --- a/ext/NNlibCUDAExt/scatter.jl +++ b/ext/NNlibCUDAExt/scatter.jl @@ -1,5 +1,5 @@ # supported op: +, -, *, /, max, min, &, |, mean -import CUDA: CUDA.CUSPARSE.AbstractcuSparseArray +import CUDA.CUSPARSE: AbstractCuSparseArray function scatter_kernel!(op::OP, dst, src, idx) where OP index = threadIdx().x + (blockIdx().x - 1) * blockDim().x From 1551922af093e11ef9f343ad9b2ee357216dcc1f Mon Sep 17 00:00:00 2001 From: Alonso Martinez Cisneros Date: Fri, 14 Nov 2025 15:40:21 +0100 Subject: [PATCH 3/4] restore formatting --- ext/NNlibCUDAExt/scatter.jl | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/ext/NNlibCUDAExt/scatter.jl b/ext/NNlibCUDAExt/scatter.jl index 29e135a8d..485fca39b 100644 --- a/ext/NNlibCUDAExt/scatter.jl +++ b/ext/NNlibCUDAExt/scatter.jl @@ -24,7 +24,7 @@ function scatter_kernel!(op::OP, dst, src, idx, max_idx, max_dims_idx, dims_size index = threadIdx().x + (blockIdx().x - 1) * blockDim().x @inbounds if index <= max_idx - j, k = divrem(index - 1, max_dims_idx) + j, k = divrem(index-1, max_dims_idx) dims_i = CartesianIndices(dims_size)[k+1] CUDA.@atomic dst[Tuple(dims_i)..., idx[j+1]...] = op(dst[Tuple(dims_i)..., idx[j+1]...], src[index]) end @@ -32,11 +32,11 @@ function scatter_kernel!(op::OP, dst, src, idx, max_idx, max_dims_idx, dims_size end function scatter_kernel!(op::OP, dst, src, idx::CUDA.CuDeviceArray{<:CartesianIndex}, - max_idx, max_dims_idx, dims_size) where OP + max_idx, max_dims_idx, dims_size) where OP index = threadIdx().x + (blockIdx().x - 1) * blockDim().x @inbounds if index <= max_idx - j, k = divrem(index - 1, max_dims_idx) + j, k = divrem(index-1, max_dims_idx) dims_i = CartesianIndices(dims_size)[k+1] li = Base._to_linear_index(dst, Tuple(dims_i)..., Tuple(idx[j+1])...) CUDA.@atomic dst[li] = op(dst[li], src[index]) @@ -59,7 +59,7 @@ function NNlib.scatter!(op::OP, dst::Union{AnyCuArray,AbstractCuSparseArray}, op, dst, src, idx, max_idx, max_dims_idx, dims_size end - kernel = @cuda launch = false scatter_kernel!(args...) + kernel = @cuda launch=false scatter_kernel!(args...) config = launch_configuration(kernel.fun; max_threads=256) threads = min(max_idx, config.threads) blocks = cld(max_idx, threads) @@ -99,7 +99,7 @@ function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx, end function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx::CUDA.CuDeviceArray{<:CartesianIndex}, - rev_idx, max_idx, T::Type{TT}) where {OP,TT} + rev_idx, max_idx, T::Type{TT}) where {OP,TT} index = threadIdx().x + (blockIdx().x - 1) * blockDim().x @inbounds if index <= max_idx @@ -142,7 +142,7 @@ function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx, end function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx::CUDA.CuDeviceArray{<:CartesianIndex}, - rev_idx, pre_cart_idx, max_dims_idx, max_idx, T::Type{TT}) where {OP,TT} + rev_idx, pre_cart_idx, max_dims_idx, max_idx, T::Type{TT}) where {OP,TT} index = threadIdx().x + (blockIdx().x - 1) * blockDim().x @inbounds if index <= max_idx @@ -182,7 +182,7 @@ function NNlib.∇scatter_src(op::Union{typeof(*),typeof(/)}, Δ, dst, args = op, Δsrc, src, idx, rev_idx, pre_cart_idx, max_dims_idx, max_idx, Tsrc end - kernel = @cuda launch = false ∇scatter_src_kernel!(args...) + kernel = @cuda launch=false ∇scatter_src_kernel!(args...) config = launch_configuration(kernel.fun; max_threads=256) threads = min(max_idx, config.threads) blocks = cld(max_idx, threads) From c749dfcc11449eb604504b462e0bab676bf06d71 Mon Sep 17 00:00:00 2001 From: Alonso Martinez Cisneros Date: Fri, 14 Nov 2025 15:58:59 +0100 Subject: [PATCH 4/4] Typos --- ext/NNlibCUDAExt/scatter.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/ext/NNlibCUDAExt/scatter.jl b/ext/NNlibCUDAExt/scatter.jl index 485fca39b..9b323d504 100644 --- a/ext/NNlibCUDAExt/scatter.jl +++ b/ext/NNlibCUDAExt/scatter.jl @@ -46,8 +46,8 @@ end function NNlib.scatter!(op::OP, dst::Union{AnyCuArray,AbstractCuSparseArray}, - src::Union{SpAnyCuArray,AbstractCuSparseArray}, - idx::Union{SpAnyCuArray,AbstractCuSparseArray}) where OP + src::Union{AnyCuArray,AbstractCuSparseArray}, + idx::Union{AnyCuArray,AbstractCuSparseArray}) where OP dims = NNlib.scatter_dims(dst, src, idx) args = if dims == 0 max_idx = length(idx) @@ -68,7 +68,8 @@ function NNlib.scatter!(op::OP, dst::Union{AnyCuArray,AbstractCuSparseArray}, end function NNlib.scatter!(op::typeof(mean), dst::Union{AnyCuArray,AbstractCuSparseArray}, - src::Union{AnyCuArray,AbstractCuArray}, idx::Union{AnyCuArray,AbstractCuArray}) + src::Union{AnyCuArray,AbstractCuSparseArray}, + idx::Union{AnyCuArray,AbstractCuSparseArray}) Ns = NNlib.scatter!(+, zero(dst), one.(src), idx) dst_ = NNlib.scatter!(+, zero(dst), src, idx) dst .+= NNlib.safe_div.(dst_, Ns)