Add scatter operations #255

Merged
merged 57 commits into from
Mar 5, 2021
57 commits
a0a5841
add scatter operations
yuehhua Dec 25, 2020
81ff677
Generalize scatter with op argument and accept AbstractArray
yuehhua Dec 27, 2020
b2b1b58
Update src/scatter.jl
yuehhua Dec 27, 2020
2a9f250
Update src/scatter.jl
yuehhua Dec 27, 2020
73a03f0
Update src/utils.jl
yuehhua Dec 27, 2020
a9ffc85
Separate scatter mean doc
yuehhua Dec 27, 2020
b102691
Drop redundant line of code
yuehhua Dec 27, 2020
99a5998
Update doc and variable names
yuehhua Dec 27, 2020
6478cbc
Fix variable name
yuehhua Dec 27, 2020
aa02132
Update docs
yuehhua Dec 28, 2020
f063cb9
Add gather
yuehhua Dec 28, 2020
1731b1b
Add least_dims
yuehhua Dec 28, 2020
e50b808
Support index represented in tuple
yuehhua Dec 29, 2020
436e38f
Merge gather! functions
yuehhua Dec 30, 2020
582b141
Fix typo
yuehhua Dec 30, 2020
3768de1
Scatter: scalar version
yuehhua Dec 30, 2020
09ed77a
Add dims argument
yuehhua Dec 31, 2020
f5f2c93
Update src/scatter.jl
yuehhua Jan 3, 2021
0464869
Update src/gather.jl
yuehhua Jan 3, 2021
57e3dce
Do not need colons
yuehhua Jan 3, 2021
a5a6630
Update src/scatter.jl
yuehhua Jan 3, 2021
625de76
Support gather for scalar and array version
yuehhua Jan 4, 2021
12c526c
Add bound checks for gather
yuehhua Jan 4, 2021
40e993c
Temporally drop gather_indices
yuehhua Jan 4, 2021
09c43b8
Make dims as kwargs
yuehhua Jan 5, 2021
7f07448
Fix bug
yuehhua Jan 5, 2021
2f9d8c2
Add scatter function
yuehhua Jan 6, 2021
24a19ca
Draft for scatter gradient
yuehhua Jan 10, 2021
6651a3e
change zygote-style to chainrules-style and refactor
yuehhua Jan 13, 2021
1356965
replace nothing with DoesNotExist
yuehhua Jan 13, 2021
1f80035
correct with NO_FIELDS
yuehhua Jan 13, 2021
70a6409
fix code to be compatible with v1.3
yuehhua Jan 13, 2021
2bda28c
extract ∇scatter_src
yuehhua Jan 13, 2021
8e3ad4a
extract ∇scatter_dst
yuehhua Jan 13, 2021
882a953
bug fix
yuehhua Jan 26, 2021
91d8c30
add test_rrule for testing gradient
yuehhua Feb 18, 2021
a577725
rewrite gather
CarloLucibello Feb 24, 2021
cd7c449
project cleanup
CarloLucibello Feb 24, 2021
0453f0c
add tests for scatter! and scatter
yuehhua Feb 24, 2021
28bd91f
add dimensional check for output arrays
yuehhua Feb 25, 2021
f9c5647
move gradient of scatter to another PR
yuehhua Feb 25, 2021
cdb55fc
fix conflict
yuehhua Feb 25, 2021
cebf473
move gather implementation to other PR
yuehhua Feb 25, 2021
5dfc49b
remove Compat
yuehhua Feb 26, 2021
3783697
remove dims args
yuehhua Feb 28, 2021
4f11c40
add @inbounds back
yuehhua Feb 28, 2021
010af70
remove restriction of numerical types
yuehhua Mar 1, 2021
66ded23
remove type promotion
yuehhua Mar 1, 2021
dcc6710
remove inbounds and simd annotations
yuehhua Mar 1, 2021
8faca3d
update error message
yuehhua Mar 1, 2021
e7913b0
remove @boundscheck
yuehhua Mar 1, 2021
e4f0c17
fix
yuehhua Mar 1, 2021
ec402d9
replace zeros and ones with more generic way
yuehhua Mar 1, 2021
c6213c6
remove bound checks
yuehhua Mar 2, 2021
4d5cbe8
optimize
yuehhua Mar 2, 2021
fc7360d
update docs
yuehhua Mar 3, 2021
6542ea9
remove not used utilities
yuehhua Mar 3, 2021
4 changes: 4 additions & 0 deletions src/NNlib.jl
@@ -4,7 +4,9 @@ using Pkg
using Requires
using ChainRulesCore
using Base.Broadcast: broadcasted
using Statistics: mean

const IntOrTuple = Union{Integer,Tuple}
const Numeric = Union{AbstractArray{<:T}, T} where {T<:Number}

# Include APIs
@@ -31,6 +33,8 @@ include("conv.jl")
include("conv_bias_act.jl")
include("pooling.jl")
include("upsample.jl")
include("utils.jl")
include("scatter.jl")

## Include implementations
include("impl/padding_edges.jl")
149 changes: 149 additions & 0 deletions src/scatter.jl
@@ -0,0 +1,149 @@
export scatter!, scatter

## Scatter API
# - Scatter:
# - scatter(op, src, idx)
# - scatter!(op, dst, src, idx)
# - Scatter destination backpropagation
# - ∇scatter_dst!
# - Scatter source backpropagation
# - ∇scatter_src
# - ∇scatter_src!
#

function _check_dims(Ndst, Nsrc, N, Nidx)
    @assert Ndst - N == Nsrc - Nidx "Incompatible input shapes of (dst, src, idx) = ($Ndst, $Nsrc, $Nidx)."
    dims = Ndst - N
    if dims < 0
        throw(ArgumentError("dims must be non-negative but got dims=$dims."))
    end
    return dims
end

typelength(::Type{<:Number}) = 1
typelength(::Type{<:NTuple{M}}) where M = M

"""
    scatter!(op, dst, src, idx)

Scatter operation, which scatters data in `src` and assigns it to `dst` according to `idx`.
When several values go to the same place, the specified aggregate operation `op` is applied
to reduce them. For each index `k` in `idx`, accumulate values in `dst` according to

    dst[:, ..., idx[k]...] = (op).(dst[:, ..., idx[k]...], src[:, ..., k...])

# Arguments
- `op`: operation to be applied on `dst` and `src`, e.g. `+`, `-`, `*`, `/`, `max`, `min`
and `mean`.
- `dst`: the destination for `src` to aggregate to. This argument will be mutated.
- `src`: the source data for aggregating.
- `idx`: the mapping for aggregation from source (index) to destination (value).
The indices of `idx` correspond to the indices of `src`, and the dimensions of `idx` must
align with the last few dimensions of `src`. The values of `idx` correspond to indices
of `dst` and must index the last few dimensions of `dst`.
Once the dimensions match, arrays are aligned automatically. The values of `idx` can be
of `Int` or `Tuple` type.
"""
function scatter!(op,
                  dst::AbstractArray{Tdst,Ndst},
                  src::AbstractArray{Tsrc,Nsrc},
                  idx::AbstractArray{Tidx,Nidx}) where {Tdst,Tsrc,Tidx<:IntOrTuple,Ndst,Nsrc,Nidx}
    M = typelength(Tidx)
    dims = _check_dims(Ndst, Nsrc, M, Nidx)
    scatter!(op, dst, src, idx, Val(dims))
end

function scatter!(op, dst::AbstractArray{Tdst}, src::AbstractArray{Tsrc}, idx::AbstractArray{<:IntOrTuple},
                  dims::Val{N}) where {Tdst,Tsrc,N}
    colons = Base.ntuple(_->Colon(), dims)
    for k in CartesianIndices(idx)
        dst_v = view(dst, colons..., idx[k]...)
        src_v = view(src, colons..., k)
        dst_v .= (op).(dst_v, src_v)
    end
    dst
end

function scatter!(op::typeof(mean),
                  dst::AbstractArray{Tdst,Ndst},
                  src::AbstractArray{Tsrc,Nsrc},
                  idx::AbstractArray{<:IntOrTuple,Nidx}) where {Tdst,Tsrc,Ndst,Nsrc,Nidx}
    Ns = scatter!(+, zero(dst), one.(src), idx)
    dst_ = scatter!(+, zero(dst), src, idx)
    dst .+= safe_div.(dst_, Ns)
    return dst
end


"""
    scatter(op, src, idx)

Scatter operation, which applies the specified operation to `src` according to `idx`
and returns a new array `dst`.
For each index `k` in `idx`, accumulate values in `dst` according to

    dst[:, ..., idx[k]...] = (op).(src[:, ..., k...])

# Arguments
- `op`: operation to be applied on `dst` and `src`, e.g. `+`, `-`, `*`, `/`, `max`, `min`
and `mean`.
- `src`: the source data for aggregating.
- `idx`: the mapping for aggregation from source (index) to destination (value).
The indices of `idx` correspond to the indices of `src` and the values of `idx`
correspond to the indices of `dst`. The values of `idx` can be of `Int` or `Tuple` type.
"""
function scatter end

for op in [+, -]
    @eval function scatter(op::typeof($op),
                           src::AbstractArray{T,Nsrc},
                           idx::AbstractArray{<:IntOrTuple,Nidx}) where {T,Nsrc,Nidx}
        dims = Nsrc - Nidx
        dstsize = (size(src)[1:dims]..., maximum_dims(idx)...)
        dst = similar(src, T, dstsize)
        fill!(dst, Base.reduce_empty(+, T))
        scatter!(op, dst, src, idx)
    end
end

for op in [*, /]
    @eval function scatter(op::typeof($op),
                           src::AbstractArray{T,Nsrc},
                           idx::AbstractArray{<:IntOrTuple,Nidx}) where {T,Nsrc,Nidx}
        dims = Nsrc - Nidx
        dstsize = (size(src)[1:dims]..., maximum_dims(idx)...)
        dst = similar(src, T, dstsize)
        fill!(dst, Base.reduce_empty(*, T))
        scatter!(op, dst, src, idx)
    end
end

function scatter(op::typeof(max),
                 src::AbstractArray{T,Nsrc},
                 idx::AbstractArray{<:IntOrTuple,Nidx}) where {T,Nsrc,Nidx}
    dims = Nsrc - Nidx
    dstsize = (size(src)[1:dims]..., maximum_dims(idx)...)
    dst = similar(src, T, dstsize)
    fill!(dst, typemin(T))
    scatter!(op, dst, src, idx)
end

function scatter(op::typeof(min),
                 src::AbstractArray{T,Nsrc},
                 idx::AbstractArray{<:IntOrTuple,Nidx}) where {T,Nsrc,Nidx}
    dims = Nsrc - Nidx
    dstsize = (size(src)[1:dims]..., maximum_dims(idx)...)
    dst = similar(src, T, dstsize)
    fill!(dst, typemax(T))
Member

I wonder if `typemax` is the right thing to do here. The problem is that if there are positions in `dst` which receive no contributions from `src`, they will end up holding `typemax`, which doesn't seem meaningful. Maybe we should error out in such cases, but doing this check may have a performance impact.

Member Author

Yeah, I thought about this issue before. Checking that every position of `dst` is covered by `idx` is the way to avoid holding `typemax`. But it is still necessary to ensure that the values in `src` are smaller than the initial value we assign, whether it comes from `typemax` or from `similar`. `similar` returns whatever happens to be in uninitialized memory, so we cannot know whether the values in `src` are smaller than it.

Member Author

What if we initialize with the maximum of `src`? Then every resulting value is less than or equal to the maximum of `src`.

Member

I think `maximum(src)` would be more surprising in unvisited entries. `typemax` seems OK to me.
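
To make the concern in this thread concrete, here is a minimal, self-contained sketch of what happens to destination positions that no index points to. `scatter_min` is a hypothetical one-dimensional stand-in written only for illustration, not the function in this PR; it just copies the PR's choice of filling the output with `typemax(T)`.

```julia
# Hypothetical 1-D stand-in for the `min` method of `scatter` (illustration only).
function scatter_min(src::AbstractVector{T}, idx::AbstractVector{<:Integer}) where {T}
    # Same initial value as the PR's `min` method: the identity element of `min`.
    dst = fill(typemax(T), maximum(idx))
    for k in eachindex(idx)
        dst[idx[k]] = min(dst[idx[k]], src[k])
    end
    return dst
end

out = scatter_min([5, 2, 7], [1, 1, 4])

# Positions 2 and 3 receive no contribution, so they still hold typemax(Int).
unvisited = findall(==(typemax(Int)), out)
```

Comparing against `typemax` afterwards is only a heuristic for spotting unvisited entries, since a genuine `typemax` value in `src` would look the same.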

    scatter!(op, dst, src, idx)
end

function scatter(op::typeof(mean),
                 src::AbstractArray{T,Nsrc},
                 idx::AbstractArray{<:IntOrTuple,Nidx}) where {T,Nsrc,Nidx}
    FT = float(T)
    dims = Nsrc - Nidx
    dstsize = (size(src)[1:dims]..., maximum_dims(idx)...)
    dst = similar(src, T, dstsize)
    fill!(dst, Base.reduce_empty(+, FT))
    scatter!(op, dst, src, idx)
end
yuehhua marked this conversation as resolved.
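
The accumulation rule in the `scatter!` docstring can be tried out without loading the package. The following is a simplified sketch restricted to integer indices; `scatter_sketch!` is a hypothetical name, and the function is a re-statement of the generic loop in this diff rather than the code that was merged. The inputs are taken from `test/scatter.jl` in this PR.

```julia
# Simplified sketch of the generic scatter loop, integer indices only.
# `scatter_sketch!` is a hypothetical name used for illustration.
function scatter_sketch!(op, dst::AbstractArray, src::AbstractArray,
                         idx::AbstractArray{<:Integer})
    # Leading dimensions shared by `dst` and `src`; they are copied as whole slices.
    dims = ndims(src) - ndims(idx)
    colons = ntuple(_ -> Colon(), dims)
    for k in CartesianIndices(idx)
        dst_v = view(dst, colons..., idx[k])   # slice of dst selected by the value idx[k]
        src_v = view(src, colons..., k)        # slice of src at position k
        dst_v .= op.(dst_v, src_v)
    end
    return dst
end

idx = [1 2 3 4;
       4 2 1 3;
       3 5 5 3]

# dims == 0: each element of src is added into dst[idx[k]].
r0 = scatter_sketch!(+, [3, 4, 5, 6, 7], ones(Int, 3, 4), idx)

# dims == 1: each column src[:, k] is added into the column dst[:, idx[k]].
r1 = scatter_sketch!(+, [3 3 4 4 5; 5 5 6 6 7], ones(Int, 2, 3, 4), idx)
```

Index `3` occurs four times in `idx`, so the third destination slot grows by four, while every other slot grows by two.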
18 changes: 18 additions & 0 deletions src/utils.jl
@@ -0,0 +1,18 @@
"""
    safe_div(x, y)

Safely divide `x` by `y`. If `y` is zero, return `x` directly.
"""
safe_div(x, y) = ifelse(iszero(y), x, x/y)

"""
    maximum_dims(dims)

Return the maximum value for each dimension. `dims` is an array whose elements are either
integers or tuples of integers. For tuples, the maximum is computed component-wise.
"""
maximum_dims(dims::AbstractArray{<:Integer}) = (maximum(dims), )

function maximum_dims(dims::AbstractArray{<:Tuple})
    Tuple(maximum(xs) for xs in zip(dims...))
end
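
As a quick illustration of how these two helpers behave, the snippet below copies their definitions from this diff so that it runs on its own. The `sums` and `counts` vectors are made-up values that mimic what the `mean` method of `scatter!` computes internally.

```julia
# Definitions copied from src/utils.jl in this diff so the snippet is self-contained.
safe_div(x, y) = ifelse(iszero(y), x, x / y)
maximum_dims(dims::AbstractArray{<:Integer}) = (maximum(dims),)
maximum_dims(dims::AbstractArray{<:Tuple}) = Tuple(maximum(xs) for xs in zip(dims...))

# Made-up per-slot sums and counts: the second slot received no contribution.
sums   = [4.0, 0.0]
counts = [2.0, 0.0]
means  = safe_div.(sums, counts)   # the empty slot stays 0.0 instead of becoming NaN

# The output size along the scattered dimensions is the component-wise maximum index.
sz_int = maximum_dims([1, 2, 3, 4, 5, 6])
sz_tup = maximum_dims([(3, 4, 5), (1, 2, 3), (2, 3, 9)])
```

Note that `ifelse` evaluates both branches, so `x / y` is still computed for a zero `y`; for the numeric types tested here that yields a `NaN` or `Inf` that is simply discarded.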
8 changes: 8 additions & 0 deletions test/runtests.jl
@@ -41,3 +41,11 @@ end
@testset "Upsampling" begin
    include("upsample.jl")
end

@testset "Scatter" begin
    include("scatter.jl")
end

@testset "Utilities" begin
    include("utils.jl")
end
175 changes: 175 additions & 0 deletions test/scatter.jl
@@ -0,0 +1,175 @@
dsts = Dict(
    0 => [3, 4, 5, 6, 7],
    1 => [3 3 4 4 5;
          5 5 6 6 7],
)
srcs = Dict(
    (0, true) => ones(Int, 3, 4),
    (0, false) => ones(Int, 3) * collect(1:4)',
    (1, true) => ones(Int, 2, 3, 4),
    (1, false) => [1, 2] .* reshape(ones(Int, 3) * collect(1:4)', 1,3,4),
)
idxs = Dict(
    :int => [1 2 3 4;
             4 2 1 3;
             3 5 5 3],
    :tup => [(1,) (2,) (3,) (4,);
             (4,) (2,) (1,) (3,);
             (3,) (5,) (5,) (3,)],
)
res = Dict(
    (+, 0, true) => [5, 6, 9, 8, 9],
    (+, 1, true) => [5 5 8 6 7;
                     7 7 10 8 9],
    (+, 0, false) => [4, 4, 12, 5, 5],
    (+, 1, false) => [4 4 12 5 5;
                      8 8 24 10 10],
    (-, 0, true) => [1, 2, 1, 4, 5],
    (-, 1, true) => [1 1 0 2 3;
                     3 3 2 4 5],
    (-, 0, false) => [-4, -4, -12, -5, -5],
    (-, 1, false) => [-4 -4 -12 -5 -5;
                      -8 -8 -24 -10 -10],
    (max, 0, true) => [3, 4, 5, 6, 7],
    (max, 1, true) => [3 3 4 4 5;
                       5 5 6 6 7],
    (max, 0, false) => [3, 2, 4, 4, 3],
    (max, 1, false) => [3 2 4 4 3;
                        6 4 8 8 6],
    (min, 0, true) => [1, 1, 1, 1, 1],
    (min, 1, true) => [1 1 1 1 1;
                       1 1 1 1 1],
    (min, 0, false) => [1, 2, 1, 1, 2],
    (min, 1, false) => [1 2 1 1 2;
                        2 4 2 2 4],
    (*, 0, true) => [3, 4, 5, 6, 7],
    (*, 1, true) => [3 3 4 4 5;
                     5 5 6 6 7],
    (*, 0, false) => [3, 4, 48, 4, 6],
    (*, 1, false) => [3 4 48 4 6;
                      12 16 768 16 24],
    (/, 0, true) => [0.75, 1., 0.3125, 1.5, 1.75],
    (/, 1, true) => [0.75 0.75 0.25 1. 1.25;
                     1.25 1.25 0.375 1.5 1.75],
    (/, 0, false) => [1//3, 1//4, 1//48, 1//4, 1//6],
    (/, 1, false) => [1//3 1//4 1//48 1//4 1//6;
                      1//12 1//16 1//768 1//16 1//24],
    (mean, 0, true) => [4., 5., 6., 7., 8.],
    (mean, 1, true) => [4. 4. 5. 5. 6.;
                        6. 6. 7. 7. 8.],
    (mean, 0, false) => [2, 2, 3, 2.5, 2.5],
    (mean, 1, false) => [2. 2. 3. 2.5 2.5;
                         4. 4. 6. 5. 5.],
)

types = [UInt8, UInt16, UInt32, UInt64, UInt128,
         Int8, Int16, Int32, Int64, Int128, BigInt,
         Float16, Float32, Float64, BigFloat, Rational]

@testset "scatter" begin
    for T = types
        @testset "$T" begin
            PT = promote_type(T, Int)
            @testset "+" begin
                for idx = values(idxs), dims = [0, 1]
                    mutated = true
                    @test scatter!(+, T.(copy(dsts[dims])), T.(srcs[(dims, mutated)]), idx) == T.(res[(+, dims, mutated)])
                    @test scatter!(+, T.(copy(dsts[dims])), srcs[(dims, mutated)], idx) == PT.(res[(+, dims, mutated)])
                    @test scatter!(+, copy(dsts[dims]), T.(srcs[(dims, mutated)]), idx) == PT.(res[(+, dims, mutated)])

                    mutated = false
                    @test scatter(+, T.(srcs[(dims, mutated)]), idx) == T.(res[(+, dims, mutated)])
                end
            end

            @testset "-" begin
                for idx = values(idxs), dims = [0, 1]
                    mutated = true
                    @test scatter!(-, T.(copy(dsts[dims])), T.(srcs[(dims, mutated)]), idx) == T.(res[(-, dims, mutated)])
                    @test scatter!(-, T.(copy(dsts[dims])), srcs[(dims, mutated)], idx) == PT.(res[(-, dims, mutated)])
                    @test scatter!(-, copy(dsts[dims]), T.(srcs[(dims, mutated)]), idx) == PT.(res[(-, dims, mutated)])

                    mutated = false
                    if !(T in [UInt8, UInt16, UInt32, UInt64, UInt128])
                        @test scatter(-, T.(srcs[(dims, mutated)]), idx) == T.(res[(-, dims, mutated)])
                    end
                end
            end

            @testset "max" begin
                for idx = values(idxs), dims = [0, 1]
                    mutated = true
                    @test scatter!(max, T.(copy(dsts[dims])), T.(srcs[(dims, mutated)]), idx) == T.(res[(max, dims, mutated)])
                    @test scatter!(max, T.(copy(dsts[dims])), srcs[(dims, mutated)], idx) == PT.(res[(max, dims, mutated)])
                    @test scatter!(max, copy(dsts[dims]), T.(srcs[(dims, mutated)]), idx) == PT.(res[(max, dims, mutated)])

                    mutated = false
                    if !(T in [BigInt])
                        @test scatter(max, T.(srcs[(dims, mutated)]), idx) == T.(res[(max, dims, mutated)])
                    end
                end
            end

            @testset "min" begin
                for idx = values(idxs), dims = [0, 1]
                    mutated = true
                    @test scatter!(min, T.(copy(dsts[dims])), T.(srcs[(dims, mutated)]), idx) == T.(res[(min, dims, mutated)])
                    @test scatter!(min, T.(copy(dsts[dims])), srcs[(dims, mutated)], idx) == PT.(res[(min, dims, mutated)])
                    @test scatter!(min, copy(dsts[dims]), T.(srcs[(dims, mutated)]), idx) == PT.(res[(min, dims, mutated)])

                    mutated = false
                    if !(T in [BigInt])
                        @test scatter(min, T.(srcs[(dims, mutated)]), idx) == T.(res[(min, dims, mutated)])
                    end
                end
            end

            @testset "*" begin
                for idx = values(idxs), dims = [0, 1]
                    mutated = true
                    @test scatter!(*, T.(copy(dsts[dims])), T.(srcs[(dims, mutated)]), idx) == T.(res[(*, dims, mutated)])
                    @test scatter!(*, T.(copy(dsts[dims])), srcs[(dims, mutated)], idx) == PT.(res[(*, dims, mutated)])
                    @test scatter!(*, copy(dsts[dims]), T.(srcs[(dims, mutated)]), idx) == PT.(res[(*, dims, mutated)])

                    mutated = false
                    if !(T in [UInt8, Int8])
                        @test scatter(*, T.(srcs[(dims, mutated)]), idx) == T.(res[(*, dims, mutated)])
                    end
                end
            end
        end
    end

    for T = [Float16, Float32, Float64, BigFloat, Rational]
        @testset "$T" begin
            PT = promote_type(T, Float64)
            @testset "/" begin
                for idx = values(idxs), dims = [0, 1]
                    mutated = true
                    @test scatter!(/, T.(dsts[dims]), T.(srcs[(dims, mutated)].*2), idx) == T.(res[(/, dims, mutated)])
                    @test scatter!(/, T.(dsts[dims]), srcs[(dims, mutated)].*2, idx) == PT.(res[(/, dims, mutated)])
                    @test scatter!(/, T.(dsts[dims]), T.(srcs[(dims, mutated)].*2), idx) == PT.(res[(/, dims, mutated)])

                    mutated = false
                    @test scatter(/, T.(srcs[(dims, mutated)]), idx) == T.(res[(/, dims, mutated)])
                end
            end

            @testset "mean" begin
                for idx = values(idxs), dims = [0, 1]
                    mutated = true
                    @test scatter!(mean, T.(dsts[dims]), T.(srcs[(dims, mutated)]), idx) == T.(res[(mean, dims, mutated)])
                    @test scatter!(mean, T.(dsts[dims]), srcs[(dims, mutated)], idx) == PT.(res[(mean, dims, mutated)])
                    @test scatter!(mean, copy(dsts[dims]), T.(srcs[(dims, mutated)]), idx) == PT.(res[(mean, dims, mutated)])

                    mutated = false
                    @test scatter(mean, T.(srcs[(dims, mutated)]), idx) == T.(res[(mean, dims, mutated)])
                end
            end
        end
    end

    @test_throws AssertionError scatter!(+, dsts[0], srcs[(1, true)], idxs[:int])
    idx = [1 2 3 4; 4 2 1 3; 6 7 8 9]
    @test_throws BoundsError scatter!(+, dsts[1], srcs[(1, true)], idx)
end
9 changes: 9 additions & 0 deletions test/utils.jl
@@ -0,0 +1,9 @@
@testset "maximum_dims" begin
    ind1 = [1,2,3,4,5,6]
    @test NNlib.maximum_dims(ind1) == (6,)
    ind2 = [(3,4,5), (1,2,3), (2,3,9)]
    @test NNlib.maximum_dims(ind2) == (3,4,9)
    ind3 = [(3,4,5) (1,2,3) (2,3,9);
            (4,6,2) (5,3,2) (4,4,4)]
    @test NNlib.maximum_dims(ind3) == (5,6,9)
end