Skip to content

Commit

Permalink
Implement accumulate and friends (#702)
Browse files Browse the repository at this point in the history
* Implement accumulate and friends

* Run tests for accumulate

* Skip inference tests in Julia 1.1

* Update src/mapreduce.jl

Co-Authored-By: Chris Foster <chris42f@gmail.com>

* Rename: _maybeval -> _maybe_val

* Explain how `_map` is used from `_accumulate`

* Revert: (push(ys, y), y)

This reverts commit 4ca0144.

* Comment on why we use `vcat`

* Use inference to determine element types

* Use reduce_empty in cumsum/cumprod for Array-compatibility

* Use reduce_first instead of reduce_empty

Co-authored-by: Chris Foster <chris42f@gmail.com>
  • Loading branch information
tkf and c42f committed Feb 18, 2020
1 parent e48d2f0 commit 7a78efd
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 0 deletions.
50 changes: 50 additions & 0 deletions src/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -285,3 +285,53 @@ end
@inbounds return similar_type(a, T, Size($Snew))(tuple($(exprs...)))
end
end

struct _InitialValue end

_maybe_val(dims::Integer) = Val(Int(dims))
_maybe_val(dims) = dims
_valof(::Val{D}) where D = D

@inline Base.accumulate(op::F, a::StaticVector; dims = :, init = _InitialValue()) where {F} =
_accumulate(op, a, _maybe_val(dims), init)

@inline Base.accumulate(op::F, a::StaticArray; dims, init = _InitialValue()) where {F} =
_accumulate(op, a, _maybe_val(dims), init)

@inline function _accumulate(op::F, a::StaticArray, dims::Union{Val,Colon}, init) where {F}
# Adjoin the initial value to `op`:
rf(x, y) = x isa _InitialValue ? Base.reduce_first(op, y) : op(x, y)

if isempty(a)
T = return_type(rf, Tuple{typeof(init), eltype(a)})
return similar_type(a, T)()
end

# StaticArrays' `reduce` is `foldl`:
results = _reduce(
a,
dims,
(init = (similar_type(a, Union{}, Size(0))(), init),),
) do (ys, acc), x
y = rf(acc, x)
# Not using `push(ys, y)` here since we need to widen element type as
# we iterate.
(vcat(ys, SA[y]), y)
end
dims === (:) && return first(results)

ys = map(first, results)
# Now map over all indices of `a`. Since `_map` needs at least
# one `StaticArray` to be passed, we pass `a` here, even though
# the values of `a` are not used.
data = _map(a, CartesianIndices(a)) do _, CI
D = _valof(dims)
I = Tuple(CI)
J = setindex(I, 1, D)
ys[J...][I[D]]
end
return similar_type(a, eltype(data))(data)
end

@inline Base.cumsum(a::StaticArray; kw...) = accumulate(Base.add_sum, a; kw...)
@inline Base.cumprod(a::StaticArray; kw...) = accumulate(Base.mul_prod, a; kw...)
66 changes: 66 additions & 0 deletions test/accumulate.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
using StaticArrays, Test

@testset "accumulate" begin
@testset "cumsum(::$label)" for (label, T) in [
# label, T
("SVector", SVector),
("MVector", MVector),
("SizedVector", SizedVector),
]
@testset "$label" for (label, a) in [
("[1, 2, 3]", T{3}(SA[1, 2, 3])),
("[]", T{0,Int}(())),
]
@test cumsum(a) == cumsum(collect(a))
@test cumsum(a) isa similar_type(a)
@inferred cumsum(a)
end
@test eltype(cumsum(T{0,Int8}(()))) == eltype(cumsum(Int8[]))
@test eltype(cumsum(T{1,Int8}((1)))) == eltype(cumsum(Int8[1]))
@test eltype(cumsum(T{2,Int8}((1, 2)))) == eltype(cumsum(Int8[1, 2]))
end

@testset "cumsum(::$label; dims=2)" for (label, T) in [
# label, T
("SMatrix", SMatrix),
("MMatrix", MMatrix),
("SizedMatrix", SizedMatrix),
]
@testset "$label" for (label, a) in [
("[1 2; 3 4; 5 6]", T{3,2}(SA[1 2; 3 4; 5 6])),
("0 x 2 matrix", T{0,2,Float64}()),
("2 x 0 matrix", T{2,0,Float64}()),
]
@test cumsum(a; dims = 2) == cumsum(collect(a); dims = 2)
@test cumsum(a; dims = 2) isa similar_type(a)
v"1.1" <= VERSION < v"1.2" && continue
@inferred cumsum(a; dims = Val(2))
end
end

@testset "cumsum(a::SArray; dims=$i); ndims(a) = $d" for d in 1:4, i in 1:d
shape = Tuple(1:d)
a = similar_type(SArray, Int, Size(shape))(1:prod(shape))
@test cumsum(a; dims = i) == cumsum(collect(a); dims = i)
@test cumsum(a; dims = i) isa SArray
v"1.1" <= VERSION < v"1.2" && continue
@inferred cumsum(a; dims = Val(i))
end

@testset "cumprod" begin
a = SA[1, 2, 3]
@test cumprod(a)::SArray == cumprod(collect(a))
@inferred cumprod(a)

@test eltype(cumsum(SA{Int8}[])) == eltype(cumsum(Int8[]))
@test eltype(cumsum(SA{Int8}[1])) == eltype(cumsum(Int8[1]))
@test eltype(cumsum(SA{Int8}[1, 2])) == eltype(cumsum(Int8[1, 2]))
end

@testset "empty vector with init" begin
a = SA{Int}[]
right(_, x) = x
@test accumulate(right, a; init = Val(1)) === SA{Int}[]
@inferred accumulate(right, a; init = Val(1))
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ include("abstractarray.jl")
include("indexing.jl")
include("initializers.jl")
Random.seed!(42); include("mapreduce.jl")
Random.seed!(42); include("accumulate.jl")
Random.seed!(42); include("arraymath.jl")
include("broadcast.jl")
include("linalg.jl")
Expand Down

0 comments on commit 7a78efd

Please sign in to comment.