From 67d3602d96b4179ee179dd088ab2886c70608eb3 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 16 Aug 2017 00:39:18 +0200 Subject: [PATCH 1/3] Type stable implementation of ArrayPartition --- src/array_partition.jl | 308 +++++++++++++++++++++++++++++++--------- test/partitions_test.jl | 30 +++- 2 files changed, 267 insertions(+), 71 deletions(-) diff --git a/src/array_partition.jl b/src/array_partition.jl index bb7c4fa3..8d9f4c9b 100644 --- a/src/array_partition.jl +++ b/src/array_partition.jl @@ -1,51 +1,129 @@ -immutable ArrayPartition{T} <: AbstractVector{Any} - x::T +struct ArrayPartition{T,S<:Tuple} <: AbstractVector{T} + x::S end + +## constructors + ArrayPartition(x...) = ArrayPartition((x...)) -function ArrayPartition{T,T2<:Tuple}(x::T2,::Type{Val{T}}=Val{false}) - if T - return ArrayPartition{T2}(((copy(a) for a in x)...)) + +function ArrayPartition(x::S, ::Type{Val{copy}}=Val{false}) where {S<:Tuple,copy} + T = promote_type(eltype.(x)...) + + if copy + return ArrayPartition{T,S}(copy.(x)) else - return ArrayPartition{T2}((x...)) + return ArrayPartition{T,S}(x) end end -Base.similar(A::ArrayPartition) = ArrayPartition((similar.(A.x))...) -Base.similar(A::ArrayPartition, dims::Tuple) = ArrayPartition((similar.(A.x))...) # Ignore dims / indices since it's a vector -Base.similar{T}(A::ArrayPartition, ::Type{T}) = ArrayPartition(similar.(A.x, T)...) -Base.similar{T}(A::ArrayPartition, ::Type{T}, dims::Tuple) = ArrayPartition(similar.(A.x, T, dims)...) -Base.zeros(A::ArrayPartition) = ArrayPartition((zeros(x) for x in A.x)...) -Base.zeros(A::ArrayPartition, dims::Tuple) = ArrayPartition((zeros.(A.x))...) # Ignore dims / indices since it's a vector -Base.zeros{T}(A::ArrayPartition, ::Type{T}) = ArrayPartition(zeros.(A.x, T)...) -Base.zeros{T}(A::ArrayPartition, ::Type{T}, dims::Tuple) = ArrayPartition(zeros.(A.x, T, dims)...) +## similar array partitions -Base.copy(A::ArrayPartition) = Base.similar(A) -Base.eltype(A::ArrayPartition) = eltype(A.x[1]) +Base.similar(A::ArrayPartition{T,S}) where {T,S} = ArrayPartition{T,S}(similar.(A.x)) -# Special to work with units -function Base.ones(A::ArrayPartition) - B = similar(A::ArrayPartition) - for i in eachindex(A.x) - B.x[i] .= eltype(A.x[i])(one(first(A.x[i]))) - end - B -end - -Base.:+(A::ArrayPartition, B::ArrayPartition) = - ArrayPartition((x .+ y for (x,y) in zip(A.x,B.x))...) -Base.:+(A::Number, B::ArrayPartition) = ArrayPartition((A .+ x for x in B.x)...) -Base.:+(A::ArrayPartition, B::Number) = ArrayPartition((B .+ x for x in A.x)...) -Base.:-(A::ArrayPartition, B::ArrayPartition) = - ArrayPartition((x .- y for (x,y) in zip(A.x,B.x))...) -Base.:-(A::Number, B::ArrayPartition) = ArrayPartition((A .- x for x in B.x)...) -Base.:-(A::ArrayPartition, B::Number) = ArrayPartition((x .- B for x in A.x)...) -Base.:*(A::Number, B::ArrayPartition) = ArrayPartition((A .* x for x in B.x)...) -Base.:*(A::ArrayPartition, B::Number) = ArrayPartition((x .* B for x in A.x)...) -Base.:/(A::ArrayPartition, B::Number) = ArrayPartition((x ./ B for x in A.x)...) -Base.:\(A::Number, B::ArrayPartition) = ArrayPartition((x ./ A for x in B.x)...) - -@inline function Base.getindex( A::ArrayPartition,i::Int) - @boundscheck i > length(A) && throw(BoundsError("Index out of bounds")) +# ignore dims since array partitions are vectors +Base.similar(A::ArrayPartition, dims::NTuple{N,Int}) where {N} = similar(A) + +# similar array partition of common type +@generated function Base.similar(A::ArrayPartition, ::Type{T}) where {T} + N = npartitions(A) + expr = :(similar(A.x[i], T)) + + build_arraypartition(N, expr) +end + +# ignore dims since array partitions are vectors +Base.similar(A::ArrayPartition, ::Type{T}, dims::NTuple{N,Int}) where {T,N} = similar(A, T) + +# similar array partition with different types +@generated function Base.similar(A::ArrayPartition, ::Type{T}, ::Type{S}, + R::Vararg{Type}) where {T,S} + N = npartitions(A) + N != length(R) + 2 && + throw(DimensionMismatch("number of types must be equal to number of partitions")) + + types = (T, S, parameter.(R)) # new types + expr = :(similar(A.x[i], ($types)[i])) + + build_arraypartition(N, expr) +end + +Base.copy(A::ArrayPartition{T,S}) where {T,S} = ArrayPartition{T,S}(copy.(A.x)) + +## zeros + +Base.zeros(A::ArrayPartition{T,S}) where {T,S} = ArrayPartition{T,S}(zeros.(A.x)) + +# ignore dims since array partitions are vectors +Base.zeros(A::ArrayPartition, dims::NTuple{N,Int}) where {N} = zeros(A) + +## ones + +# special to work with units +@generated function Base.ones(A::ArrayPartition) + N = npartitions(A) + + expr = :(fill!(similar(A.x[i]), oneunit(eltype(A.x[i])))) + + build_arraypartition(N, expr) +end + +# ignore dims since array partitions are vectors +Base.ones(A::ArrayPartition, dims::NTuple{N,Int}) where {N} = ones(A) + +## vector space operations + +for op in (:+, :-) + @eval begin + @generated function Base.$op(A::ArrayPartition, B::ArrayPartition) + N = npartitions(A, B) + expr = :($($op).(A.x[i], B.x[i])) + + build_arraypartition(N, expr) + end + + @generated function Base.$op(A::ArrayPartition, B::Number) + N = npartitions(A) + expr = :($($op).(A.x[i], B)) + + build_arraypartition(N, expr) + end + + @generated function Base.$op(A::Number, B::ArrayPartition) + N = npartitions(B) + expr = :($($op).(A, B.x[i])) + + build_arraypartition(N, expr) + end + end +end + +for op in (:*, :/) + @eval @generated function Base.$op(A::ArrayPartition, B::Number) + N = npartitions(A) + expr = :($($op).(A.x[i], B)) + + build_arraypartition(N, expr) + end +end + +@generated function Base.:*(A::Number, B::ArrayPartition) + N = npartitions(B) + expr = :((*).(A, B.x[i])) + + build_arraypartition(N, expr) +end + +@generated function Base.:\(A::Number, B::ArrayPartition) + N = npartitions(B) + expr = :((/).(B.x[i], A)) + + build_arraypartition(N, expr) +end + +## indexing + +@inline function Base.getindex(A::ArrayPartition, i::Int) + @boundscheck checkbounds(A, i) @inbounds for j in 1:length(A.x) i -= length(A.x[j]) if i <= 0 @@ -53,9 +131,28 @@ Base.:\(A::Number, B::ArrayPartition) = ArrayPartition((x ./ A for x in B.x)...) end end end -Base.getindex( A::ArrayPartition,::Colon) = [A[i] for i in 1:length(A)] + +""" + getindex(A::ArrayPartition, i::Int, j...) + +Return the entry at index `j...` of the `i`th partition of `A`. +""" +@inline function Base.getindex(A::ArrayPartition, i::Int, j...) + @boundscheck 0 < i <= length(A.x) || throw(BoundsError(A.x, i)) + @inbounds b = A.x[i] + @boundscheck checkbounds(b, j...) + @inbounds return b[j...] +end + +""" + getindex(A::ArrayPartition, ::Colon) + +Return vector with all elements of array partition `A`. +""" +Base.getindex(A::ArrayPartition{T,S}, ::Colon) where {T,S} = T[a for a in Chain(A.x)] + @inline function Base.setindex!(A::ArrayPartition, v, i::Int) - @boundscheck i > length(A) && throw(BoundsError("Index out of bounds")) + @boundscheck checkbounds(A, i) @inbounds for j in 1:length(A.x) i -= length(A.x[j]) if i <= 0 @@ -64,28 +161,41 @@ Base.getindex( A::ArrayPartition,::Colon) = [A[i] for i in 1:length(A)] end end end -Base.getindex( A::ArrayPartition, i::Int...) = A.x[i[1]][Base.tail(i)...] -Base.setindex!(A::ArrayPartition, v, i::Int...) = A.x[i[1]][Base.tail(i)...]=v -function recursivecopy!(A::ArrayPartition,B::ArrayPartition) - for (a,b) in zip(A.x,B.x) - copy!(a,b) +""" + setindex!(A::ArrayPartition, v, i::Int, j...) + +Set the entry at index `j...` of the `i`th partition of `A` to `v`. +""" +@inline function Base.setindex!(A::ArrayPartition, v, i::Int, j...) + @boundscheck 0 < i <= length(A.x) || throw(BoundsError(A.x, i)) + @inbounds b = A.x[i] + @boundscheck checkbounds(b, j...) + @inbounds b[j...] = v +end + +## recursive methods + +function recursivecopy!(A::ArrayPartition, B::ArrayPartition) + for (a, b) in zip(A.x, B.x) + recursivecopy!(a, b) end end recursive_one(A::ArrayPartition) = recursive_one(first(A.x)) + recursive_mean(A::ArrayPartition) = mean((recursive_mean(x) for x in A.x)) -Base.zero(A::ArrayPartition) = zero(first(A.x)) -Base.first(A::ArrayPartition) = first(first(A.x)) -Base.start(A::ArrayPartition) = start(chain(A.x...)) -Base.next(A::ArrayPartition,state) = next(chain(A.x...),state) -Base.done(A::ArrayPartition,state) = done(chain(A.x...),state) +## iteration + +Base.start(A::ArrayPartition) = start(Chain(A.x)) +Base.next(A::ArrayPartition,state) = next(Chain(A.x),state) +Base.done(A::ArrayPartition,state) = done(Chain(A.x),state) Base.length(A::ArrayPartition) = sum((length(x) for x in A.x)) Base.size(A::ArrayPartition) = (length(A),) -Base.isempty(A::ArrayPartition) = (length(A) == 0) -Base.eachindex(A::ArrayPartition) = Base.OneTo(length(A)) + +## display # restore the type rendering in Juno Juno.@render Juno.Inline x::ArrayPartition begin @@ -97,23 +207,83 @@ Base.show(io::IO,A::ArrayPartition) = (Base.show.(io,A.x); nothing) Base.display(A::ArrayPartition) = (println(summary(A));display.(A.x);nothing) Base.display(io::IO,A::ArrayPartition) = (println(summary(A));display.(io,A.x);nothing) -add_idxs(x,expr) = expr -add_idxs{T<:ArrayPartition}(::Type{T},expr) = :($(expr).x[i]) +## broadcasting + +Base.Broadcast._containertype(::Type{<:ArrayPartition}) = ArrayPartition +Base.Broadcast.promote_containertype(::Type{ArrayPartition}, ::Type) = ArrayPartition +Base.Broadcast.promote_containertype(::Type, ::Type{ArrayPartition}) = ArrayPartition +Base.Broadcast.promote_containertype(::Type{ArrayPartition}, ::Type{ArrayPartition}) = ArrayPartition +Base.Broadcast.promote_containertype(::Type{ArrayPartition}, ::Type{Array}) = ArrayPartition +Base.Broadcast.promote_containertype(::Type{Array}, ::Type{ArrayPartition}) = ArrayPartition + +@generated function Base.Broadcast.broadcast_c(f, ::Type{ArrayPartition}, as...) + # common number of partitions + N = npartitions(as...) -@generated function Base.broadcast!(f,A::ArrayPartition,B...) - exs = ((add_idxs(B[i],:(B[$i])) for i in eachindex(B))...) - :(for i in eachindex(A.x) - broadcast!(f,A.x[i],$(exs...)) - end) + # broadcast partitions separately + expr = :(broadcast(f, + # index partitions + $((as[d] <: ArrayPartition ? :(as[$d].x[i]) : :(as[$d]) + for d in 1:length(as))...))) + + build_arraypartition(N, expr) end -@generated function Base.broadcast(f,B::Union{Number,ArrayPartition}...) - arr_idx = 0 - for (i,b) in enumerate(B) - if b <: ArrayPartition - arr_idx = i - break +@generated function Base.Broadcast.broadcast_c!(f, ::Type{ArrayPartition}, ::Type, + dest::ArrayPartition, as...) + # common number of partitions + N = npartitions(dest, as...) + + # broadcast partitions separately + quote + for i in 1:$N + broadcast!(f, dest.x[i], + # index partitions + $((as[d] <: ArrayPartition ? :(as[$d].x[i]) : :(as[$d]) + for d in 1:length(as))...)) + end + dest + end +end + +## utils + +""" + build_arraypartition(N::Int, expr::Expr) + +Build `ArrayPartition` consisting of `N` partitions, each the result of an evaluation of +`expr` with variable `i` set to the partition index in the range of 1 to `N`. + +This can help to write a type-stable method in cases in which the correct return type can +can not be inferred for a simpler implementation with generators. +""" +function build_arraypartition(N::Int, expr::Expr) + quote + @Base.nexprs $N i->(A_i = $expr) + partitions = @Base.ncall $N tuple i->A_i + ArrayPartition(partitions) end - end - :(A = similar(B[$arr_idx]); broadcast!(f,A,B...); A) end + +""" + npartitions(A...) + +Retrieve number of partitions of `ArrayPartitions` in `A...`, or throw an error if there are +`ArrayPartitions` with a different number of partitions. +""" +npartitions(A) = 0 +npartitions(::Type{ArrayPartition{T,S}}) where {T,S} = length(S.parameters) +npartitions(A, B...) = common_number(npartitions(A), npartitions(B...)) + +common_number(a, b) = + a == 0 ? b : + (b == 0 ? a : + (a == b ? a : + throw(DimensionMismatch("number of partitions must be equal")))) + +""" + parameter(::Type{T}) + +Return type `T` of singleton. +""" +parameter(::Type{T}) where {T} = T diff --git a/test/partitions_test.jl b/test/partitions_test.jl index d262f66e..5f94ff3b 100644 --- a/test/partitions_test.jl +++ b/test/partitions_test.jl @@ -24,12 +24,38 @@ p .= (*).(p,a) p .= (*).(p,p2) K = (*).(p,p2) +## inference tests + x = ArrayPartition([1, 2], [3.0, 4.0]) + +# similar partitions @inferred(similar(x)) @inferred(similar(x, (2, 2))) -@test_broken @inferred(similar(x, Int, (2, 2))) -@test_broken @inferred(similar(x, (Int, Float64), (2, 2))) +@inferred(similar(x, Int)) +@inferred(similar(x, Int, (2, 2))) +@inferred(similar(x, Int, Float64)) + +# zeros +@inferred(zeros(x)) +@inferred(zeros(x, (2,2))) + +# ones +@inferred(ones(x)) +@inferred(ones(x, (2,2))) + +# vector space calculations +@inferred(x+5) +@inferred(5+x) +@inferred(x-5) +@inferred(5-x) +@inferred(x*5) +@inferred(5*x) +@inferred(x/5) +@inferred(5\x) +@inferred(x+x) +@inferred(x-x) +# broadcasting _scalar_op(y) = y + 1 # Can't do `@inferred(_scalar_op.(x))` so we wrap that in a function: _broadcast_wrapper(y) = _scalar_op.(y) From 5352af91a8ad85a63695b6077dec5452ad9b332a Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 16 Aug 2017 01:10:37 +0200 Subject: [PATCH 2/3] Add type-stable first and last --- src/array_partition.jl | 4 ++++ test/partitions_test.jl | 45 +++++++++++++++++++++++------------------ 2 files changed, 29 insertions(+), 20 deletions(-) diff --git a/src/array_partition.jl b/src/array_partition.jl index 8d9f4c9b..558f2916 100644 --- a/src/array_partition.jl +++ b/src/array_partition.jl @@ -195,6 +195,10 @@ Base.done(A::ArrayPartition,state) = done(Chain(A.x),state) Base.length(A::ArrayPartition) = sum((length(x) for x in A.x)) Base.size(A::ArrayPartition) = (length(A),) +# redefine first and last to avoid slow and not type-stable indexing +Base.first(A::ArrayPartition) = first(first(A.x)) +Base.last(A::ArrayPartition) = last(last(A.x)) + ## display # restore the type rendering in Juno diff --git a/test/partitions_test.jl b/test/partitions_test.jl index 5f94ff3b..62a9c6de 100644 --- a/test/partitions_test.jl +++ b/test/partitions_test.jl @@ -29,35 +29,40 @@ K = (*).(p,p2) x = ArrayPartition([1, 2], [3.0, 4.0]) # similar partitions -@inferred(similar(x)) -@inferred(similar(x, (2, 2))) -@inferred(similar(x, Int)) -@inferred(similar(x, Int, (2, 2))) -@inferred(similar(x, Int, Float64)) +@inferred similar(x) +@inferred similar(x, (2, 2)) +@inferred similar(x, Int) +@inferred similar(x, Int, (2, 2)) +@inferred similar(x, Int, Float64) # zeros -@inferred(zeros(x)) -@inferred(zeros(x, (2,2))) +@inferred zeros(x) +@inferred zeros(x, (2,2)) +@inferred zero(x) # ones -@inferred(ones(x)) -@inferred(ones(x, (2,2))) +@inferred ones(x) +@inferred ones(x, (2,2)) # vector space calculations -@inferred(x+5) -@inferred(5+x) -@inferred(x-5) -@inferred(5-x) -@inferred(x*5) -@inferred(5*x) -@inferred(x/5) -@inferred(5\x) -@inferred(x+x) -@inferred(x-x) +@inferred x+5 +@inferred 5+x +@inferred x-5 +@inferred 5-x +@inferred x*5 +@inferred 5*x +@inferred x/5 +@inferred 5\x +@inferred x+x +@inferred x-x + +# indexing +@inferred first(x) +@inferred last(x) # broadcasting _scalar_op(y) = y + 1 # Can't do `@inferred(_scalar_op.(x))` so we wrap that in a function: _broadcast_wrapper(y) = _scalar_op.(y) # Issue #8 -@inferred(_broadcast_wrapper(x)) +@inferred _broadcast_wrapper(x) From d5b657c16e50f52ec7e4e50eddd078e0ac4a88a4 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 16 Aug 2017 03:26:50 +0200 Subject: [PATCH 3/3] Define recursive_eltype --- src/array_partition.jl | 6 ++++-- test/partitions_test.jl | 5 +++++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/src/array_partition.jl b/src/array_partition.jl index 558f2916..d662cf58 100644 --- a/src/array_partition.jl +++ b/src/array_partition.jl @@ -182,10 +182,12 @@ function recursivecopy!(A::ArrayPartition, B::ArrayPartition) end end -recursive_one(A::ArrayPartition) = recursive_one(first(A.x)) - recursive_mean(A::ArrayPartition) = mean((recursive_mean(x) for x in A.x)) +# note: consider only first partition for recursive one and eltype +recursive_one(A::ArrayPartition) = recursive_one(first(A.x)) +recursive_eltype(A::ArrayPartition) = recursive_eltype(first(A.x)) + ## iteration Base.start(A::ArrayPartition) = start(Chain(A.x)) diff --git a/test/partitions_test.jl b/test/partitions_test.jl index 62a9c6de..18b6a482 100644 --- a/test/partitions_test.jl +++ b/test/partitions_test.jl @@ -60,6 +60,11 @@ x = ArrayPartition([1, 2], [3.0, 4.0]) @inferred first(x) @inferred last(x) +# recursive +@inferred recursive_mean(x) +@inferred recursive_one(x) +@inferred recursive_eltype(x) + # broadcasting _scalar_op(y) = y + 1 # Can't do `@inferred(_scalar_op.(x))` so we wrap that in a function: