Add scatter operations #255

Merged
merged 57 commits into from
Mar 5, 2021
Changes from 49 commits
Commits
a0a5841
add scatter operations
yuehhua Dec 25, 2020
81ff677
Generalize scatter with op argument and accept AbstractArray
yuehhua Dec 27, 2020
b2b1b58
Update src/scatter.jl
yuehhua Dec 27, 2020
2a9f250
Update src/scatter.jl
yuehhua Dec 27, 2020
73a03f0
Update src/utils.jl
yuehhua Dec 27, 2020
a9ffc85
Separate scatter mean doc
yuehhua Dec 27, 2020
b102691
Drop redundant line of code
yuehhua Dec 27, 2020
99a5998
Update doc and variable names
yuehhua Dec 27, 2020
6478cbc
Fix variable name
yuehhua Dec 27, 2020
aa02132
Update docs
yuehhua Dec 28, 2020
f063cb9
Add gather
yuehhua Dec 28, 2020
1731b1b
Add least_dims
yuehhua Dec 28, 2020
e50b808
Support index represented in tuple
yuehhua Dec 29, 2020
436e38f
Merge gather! functions
yuehhua Dec 30, 2020
582b141
Fix typo
yuehhua Dec 30, 2020
3768de1
Scatter: scalar version
yuehhua Dec 30, 2020
09ed77a
Add dims argument
yuehhua Dec 31, 2020
f5f2c93
Update src/scatter.jl
yuehhua Jan 3, 2021
0464869
Update src/gather.jl
yuehhua Jan 3, 2021
57e3dce
Do not need colons
yuehhua Jan 3, 2021
a5a6630
Update src/scatter.jl
yuehhua Jan 3, 2021
625de76
Support gather for scalar and array version
yuehhua Jan 4, 2021
12c526c
Add bound checks for gather
yuehhua Jan 4, 2021
40e993c
Temporally drop gather_indices
yuehhua Jan 4, 2021
09c43b8
Make dims as kwargs
yuehhua Jan 5, 2021
7f07448
Fix bug
yuehhua Jan 5, 2021
2f9d8c2
Add scatter function
yuehhua Jan 6, 2021
24a19ca
Draft for scatter gradient
yuehhua Jan 10, 2021
6651a3e
change zygote-style to chainrules-style and refactor
yuehhua Jan 13, 2021
1356965
replace nothing with DoesNotExist
yuehhua Jan 13, 2021
1f80035
correct with NO_FIELDS
yuehhua Jan 13, 2021
70a6409
fix code to be compatible with v1.3
yuehhua Jan 13, 2021
2bda28c
extract ∇scatter_src
yuehhua Jan 13, 2021
8e3ad4a
extract ∇scatter_dst
yuehhua Jan 13, 2021
882a953
bug fix
yuehhua Jan 26, 2021
91d8c30
add test_rrule for testing gradient
yuehhua Feb 18, 2021
a577725
rewrite gather
CarloLucibello Feb 24, 2021
cd7c449
project cleanup
CarloLucibello Feb 24, 2021
0453f0c
add tests for scatter! and scatter
yuehhua Feb 24, 2021
28bd91f
add dimensional check for output arrays
yuehhua Feb 25, 2021
f9c5647
move gradient of scatter to another PR
yuehhua Feb 25, 2021
cdb55fc
fix conflict
yuehhua Feb 25, 2021
cebf473
move gather implementation to other PR
yuehhua Feb 25, 2021
5dfc49b
remove Compat
yuehhua Feb 26, 2021
3783697
remove dims args
yuehhua Feb 28, 2021
4f11c40
add @inbounds back
yuehhua Feb 28, 2021
010af70
remove restriction of numerical types
yuehhua Mar 1, 2021
66ded23
remove type promotion
yuehhua Mar 1, 2021
dcc6710
remove inbounds and simd annotations
yuehhua Mar 1, 2021
8faca3d
update error message
yuehhua Mar 1, 2021
e7913b0
remove @boundscheck
yuehhua Mar 1, 2021
e4f0c17
fix
yuehhua Mar 1, 2021
ec402d9
replace zeros and ones with more generic way
yuehhua Mar 1, 2021
c6213c6
remove bound checks
yuehhua Mar 2, 2021
4d5cbe8
optimize
yuehhua Mar 2, 2021
fc7360d
update docs
yuehhua Mar 3, 2021
6542ea9
remove not used utilities
yuehhua Mar 3, 2021
4 changes: 4 additions & 0 deletions src/NNlib.jl
@@ -4,7 +4,9 @@ using Pkg
using Requires
using ChainRulesCore
using Base.Broadcast: broadcasted
using Statistics: mean

const IntOrTuple = Union{Integer,Tuple}
const Numeric = Union{AbstractArray{<:T}, T} where {T<:Number}

# Include APIs
@@ -31,6 +33,8 @@ include("conv.jl")
include("conv_bias_act.jl")
include("pooling.jl")
include("upsample.jl")
include("utils.jl")
include("scatter.jl")

## Include implementations
include("impl/padding_edges.jl")
187 changes: 187 additions & 0 deletions src/scatter.jl
@@ -0,0 +1,187 @@
export scatter!, scatter

## Scatter API
# - Scatter:
# - scatter(op, src, idx)
# - scatter!(op, dst, src, idx)
# - Scatter destination backpropagation
# - ∇scatter_dst!
# - Scatter source backpropagation
# - ∇scatter_src
# - ∇scatter_src!
#

function _check_dims(Ndst, Nsrc, N, Nidx)
    @assert Ndst - N == Nsrc - Nidx
    dims = Ndst - N
    if dims < 0
        throw(ArgumentError("dims must be non-negative but got dims=$dims"))
    end
    return dims
end

_check_input(idx::AbstractArray{<:Integer}, arr) = checkbounds(arr, minimum(idx):maximum(idx))

function _check_input(idx::AbstractArray{<:Tuple}, arr)
    pairs = map(xs -> Base.OneTo(maximum(xs)), zip(idx...))
    checkbounds(arr, pairs...)
end

function _check_output(idx::AbstractArray{<:IntOrTuple}, arr, src, dims)
    pre_dims = axes(src)[1:dims]
    post_dims = Base.OneTo.(maximum_dims(idx))
    checkbounds(arr, pre_dims..., post_dims...)
end


"""
    scatter!(op, dst, src, idx)

Scatter operation, which scatters data in `src` and assigns it to `dst` according to `idx`.
When several values go to the same place, the specified operation `op` is applied to reduce them.
For each index `k` in `idx`, accumulate values in `dst` according to

    dst[idx[k]...] = (op).(dst[idx[k]...], src[k...])

# Arguments
- `op`: the operation to be applied on `dst` and `src`, e.g. `+`, `-`, `*`, `/`, `max` and `min`.
- `dst`: the destination for `src` to aggregate to. This argument will be mutated.
- `src`: the source data for aggregating.
- `idx`: the mapping for aggregation from source (index) to destination (value).
  Each index of `idx` corresponds to an index of `src`, and each value of `idx`
  corresponds to an index of `dst`. The values of `idx` can be of `Int` or `Tuple` type.

When `dst` and `src` have leading dimensions that are not indexed by `idx`, the update is

    dst[:, ..., idx[k]...] = (op).(dst[:, ..., idx[k]...], src[:, ..., k...])
"""
function scatter!(op,
                  dst::AbstractArray{Tdst,Ndst},
                  src::AbstractArray{Tsrc,Nsrc},
                  idx::AbstractArray{<:Integer,Nidx}) where {Tdst,Tsrc,Ndst,Nsrc,Nidx}
    dims = _check_dims(Ndst, Nsrc, 1, Nidx)
    @boundscheck _check_output(idx, dst, src, dims)
    @boundscheck _check_input(idx, src)
    scatter!(op, dst, src, idx, Val(dims))
end

function scatter!(op,
                  dst::AbstractArray{Tdst,Ndst},
                  src::AbstractArray{Tsrc,Nsrc},
                  idx::AbstractArray{NTuple{N,Int},Nidx}) where {Tdst,Tsrc,Ndst,Nsrc,N,Nidx}
    dims = _check_dims(Ndst, Nsrc, N, Nidx)
    @boundscheck _check_output(idx, dst, src, dims)
    @boundscheck _check_input(idx, src)
    scatter!(op, dst, src, idx, Val(dims))
end

function scatter!(op, dst::AbstractArray{Tdst}, src::AbstractArray{Tsrc}, idx::AbstractArray{<:IntOrTuple},
                  dims::Val{N}) where {Tdst,Tsrc,N}
    colons = Base.ntuple(_->Colon(), N)
    for k in CartesianIndices(idx)
        dst_v = view(dst, colons..., idx[k]...)
        src_v = view(src, colons..., k)
        dst_v .= (op).(dst_v, src_v)
    end
    dst
end

"""
    scatter!(mean, dst, src, idx)

Scatter mean operation, which scatters data in `src` and assigns it to `dst` according to `idx`.
When several values go to the same place, their mean is taken to reduce them.
For each index `k` in `idx`, accumulate values in `dst` according to

    dst[idx[k]...] = dst[idx[k]...] + mean.(src[k...])

# Arguments
- `dst`: the destination for `src` to aggregate to. This argument will be mutated.
- `src`: the source data for aggregating.
- `idx`: the mapping for aggregation from source (index) to destination (value).
  Each index of `idx` corresponds to an index of `src`, and each value of `idx`
  corresponds to an index of `dst`. The values of `idx` can be of `Int` or `Tuple` type.

    dst[:, ..., idx[k]...] = (op).(dst[:, ..., idx[k]...], src[:, ..., k...])
"""
function scatter!(op::typeof(mean),
                  dst::AbstractArray{Tdst,Ndst},
                  src::AbstractArray{Tsrc,Nsrc},
                  idx::AbstractArray{<:Integer,Nidx}) where {Tdst,Tsrc,Ndst,Nsrc,Nidx}
    Ns = scatter!(+, zero(dst), one.(src), idx)
    dst_ = scatter!(+, zero(dst), src, idx)
    dst .+= safe_div.(dst_, Ns)
    return dst
end

function scatter!(op::typeof(mean),
                  dst::AbstractArray{Tdst,Ndst},
                  src::AbstractArray{Tsrc,Nsrc},
                  idx::AbstractArray{NTuple{N,Int},Nidx}) where {Tdst,Tsrc,Ndst,Nsrc,N,Nidx}
    Ns = scatter!(+, zero(dst), one.(src), idx)
    dst_ = scatter!(+, zero(dst), src, idx)
    dst .+= safe_div.(dst_, Ns)
    return dst
end


"""
    scatter(op, src, idx)

Scatter operation, which applies the specified operation to `src` according to `idx`
and returns a new array `dst`.
For each index `k` in `idx`, accumulate values in `dst` according to

    dst[idx[k]...] = (op).(src[k...])

# Arguments
- `op`: the operation to be applied on `dst` and `src`, e.g. `+`, `-`, `*`, `/`, `max` and `min`.
- `src`: the source data for aggregating.
- `idx`: the mapping for aggregation from source (index) to destination (value).
  Each index of `idx` corresponds to an index of `src`, and each value of `idx`
  corresponds to an index of `dst`. The values of `idx` can be of `Int` or `Tuple` type.
"""
function scatter end

for op in [+, -]
    @eval function scatter(op::typeof($op),
                           src::AbstractArray{T,Nsrc},
                           idx::AbstractArray{<:IntOrTuple,Nidx}) where {T,Nsrc,Nidx}
        dims = Nsrc - Nidx
        dst = zeros(T, size(src)[1:dims]..., maximum_dims(idx)...)
        scatter!(op, dst, src, idx)
    end
end

for op in [*, /]
    @eval function scatter(op::typeof($op),
                           src::AbstractArray{T,Nsrc},
                           idx::AbstractArray{<:IntOrTuple,Nidx}) where {T,Nsrc,Nidx}
        dims = Nsrc - Nidx
        dst = ones(T, size(src)[1:dims]..., maximum_dims(idx)...)
        scatter!(op, dst, src, idx)
    end
end

function scatter(op::typeof(max),
                 src::AbstractArray{T,Nsrc},
                 idx::AbstractArray{<:IntOrTuple,Nidx}) where {T,Nsrc,Nidx}
    dims = Nsrc - Nidx
    dst = fill(typemin(T), size(src)[1:dims]..., maximum_dims(idx)...)
    scatter!(op, dst, src, idx)
end

function scatter(op::typeof(min),
                 src::AbstractArray{T,Nsrc},
                 idx::AbstractArray{<:IntOrTuple,Nidx}) where {T,Nsrc,Nidx}
    dims = Nsrc - Nidx
    dst = fill(typemax(T), size(src)[1:dims]..., maximum_dims(idx)...)
    scatter!(op, dst, src, idx)
end

function scatter(op::typeof(mean),
                 src::AbstractArray{T,Nsrc},
                 idx::AbstractArray{<:IntOrTuple,Nidx}) where {T,Nsrc,Nidx}
    FT = float(T)
    dims = Nsrc - Nidx
    dst = zeros(FT, size(src)[1:dims]..., maximum_dims(idx)...)
    scatter!(op, dst, FT.(src), idx)
end
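
Reader's note on the `src/scatter.jl` diff above: the update rule `dst[:, ..., idx[k]...] = (op).(dst[:, ..., idx[k]...], src[:, ..., k...])` may be easier to follow on a small case. The following is a minimal sketch, not code from this PR. `scatter_sketch!` is a hypothetical stand-alone helper that assumes a matrix `src` and an integer vector `idx`, so it runs without NNlib.

```julia
# Hypothetical stand-alone helper (not part of the PR): same loop as the generic
# `scatter!` method, restricted to a matrix `src` and an integer vector `idx`.
function scatter_sketch!(op, dst::AbstractMatrix, src::AbstractMatrix,
                         idx::AbstractVector{<:Integer})
    for k in eachindex(idx)
        dst_v = view(dst, :, idx[k])   # destination column selected by idx[k]
        src_v = view(src, :, k)        # k-th source column
        dst_v .= op.(dst_v, src_v)     # reduce in place with `op`
    end
    return dst
end

src = [1 2 3 4;
       5 6 7 8]
idx = [1, 2, 1, 3]   # source columns 1 and 3 both land in destination column 1

# Initial values follow the diff: zeros for `+`, typemin for `max`.
sum_dst = scatter_sketch!(+, zeros(Int, 2, 3), src, idx)
max_dst = scatter_sketch!(max, fill(typemin(Int), 2, 3), src, idx)
```

With these inputs `sum_dst` is `[4 2 4; 12 6 8]` and `max_dst` is `[3 2 4; 7 6 8]`. The initial value of `dst` acts as the identity element of `op`, which is why the `scatter` methods in the diff build `dst` differently for each operation.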
42 changes: 42 additions & 0 deletions src/utils.jl
@@ -0,0 +1,42 @@
"""
    safe_div(x, y)

Safely divide `x` by `y`. If `y` is zero, return `x` directly.
"""
safe_div(x, y) = ifelse(iszero(y), x, x/y)

"""
    maximum_dims(dims)

Return the maximum value for each dimension. An array of dimensions `dims` is accepted.
The maximum of each dimension in the element is computed.
"""
maximum_dims(dims::AbstractArray{<:Integer}) = (maximum(dims), )

function maximum_dims(dims::AbstractArray{<:Tuple})
    Tuple(maximum(xs) for xs in zip(dims...))
end

function reverse_indices(X::Array{T}) where T
    Y = Dict{T,Vector{CartesianIndex}}()
    @inbounds for (ind, val) = pairs(X)
        Y[val] = get(Y, val, CartesianIndex[])
        push!(Y[val], ind)
    end
    Y
end

function count_indices(idx::AbstractArray, N)
    counts = zero.(idx)
    @inbounds for i = 1:N
        counts += sum(idx.==i) * (idx.==i)
    end
    counts
end

function divide_by_counts!(xs, idx::AbstractArray, N)
    counts = count_indices(idx, N)
    @inbounds for ind = CartesianIndices(counts)
        view(xs, :, ind) ./= counts[ind]
    end
end
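
Reader's note on the `src/utils.jl` diff above: `safe_div` and `maximum_dims` are what the mean variant of `scatter!` relies on, so a small sketch of how they combine may help. The two helper definitions are copied from the diff. `scatter_mean_sketch` is a hypothetical vector-only function written for illustration and is not part of this PR.

```julia
# Copied from the diff so that the sketch is self-contained.
safe_div(x, y) = ifelse(iszero(y), x, x / y)
maximum_dims(dims::AbstractArray{<:Integer}) = (maximum(dims),)
maximum_dims(dims::AbstractArray{<:Tuple}) = Tuple(maximum(xs) for xs in zip(dims...))

# Hypothetical helper: mean of the `src` entries that share a destination index.
function scatter_mean_sketch(src::AbstractVector{<:Real}, idx::AbstractVector{<:Integer})
    n = first(maximum_dims(idx))             # destination length
    sums = zeros(float(eltype(src)), n)
    counts = zeros(float(eltype(src)), n)
    for k in eachindex(idx)
        sums[idx[k]] += src[k]               # scatter with `+`
        counts[idx[k]] += 1                  # how many values hit this slot
    end
    return safe_div.(sums, counts)           # untouched slots stay at zero
end

means = scatter_mean_sketch([1.0, 3.0, 5.0], [1, 1, 3])

# Tuple indices: the result is the maximum along each position of the tuples.
extent = maximum_dims([(1, 2), (3, 1), (2, 4)])
```

Here `means` is `[2.0, 0.0, 5.0]`: slot 2 receives no value, its count is zero, and `safe_div` returns the sum unchanged instead of `NaN`. `extent` is `(3, 4)`, which is how the `scatter` methods size the trailing dimensions of `dst`.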
8 changes: 8 additions & 0 deletions test/runtests.jl
@@ -41,3 +41,11 @@ end
@testset "Upsampling" begin
    include("upsample.jl")
end

@testset "Scatter" begin
    include("scatter.jl")
end

@testset "Utilities" begin
    include("utils.jl")
end