diff --git a/src/array_partition.jl b/src/array_partition.jl index c7d8c3ec..5d2b7c55 100644 --- a/src/array_partition.jl +++ b/src/array_partition.jl @@ -51,6 +51,10 @@ Base.zero(A::ArrayPartition{T,S}) where {T,S} = ArrayPartition{T,S}(zero.(A.x)) # ignore dims since array partitions are vectors Base.zero(A::ArrayPartition, dims::NTuple{N,Int}) where {N} = zero(A) +## Array + +Base.Array(VA::AbstractVectorOfArray{T,N,A}) where {T,N,A <: AbstractVector{<:ArrayPartition}} = reduce(hcat,Array.(VA.u)) + ## ones # special to work with units diff --git a/src/zygote.jl b/src/zygote.jl index c427c3e7..7ec77d3c 100644 --- a/src/zygote.jl +++ b/src/zygote.jl @@ -14,3 +14,12 @@ ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i, j...) end VA[i,j...],AbstractVectorOfArray_getindex_adjoint end + +ZygoteRules.@adjoint function ArrayPartition(x...) + function ArrayPartition_adjoint(_y) + y = Array(_y) + starts = vcat(0,cumsum(reduce(vcat,length.(x)))) + ntuple(i -> reshape(y[starts[i]+1:starts[i+1]],size(x[i])),length(x)) + end + ArrayPartition(x...),ArrayPartition_adjoint +end