Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add stack(array_of_arrays) #43334

Merged
merged 23 commits into from
Aug 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ New library functions
inspecting which function `f` was originally wrapped. ([#42717])
* New `pkgversion(m::Module)` function to get the version of the package that loaded
a given module, similar to `pkgdir(m::Module)`. ([#45607])
* New function `stack(x)` which generalises `reduce(hcat, x::Vector{<:Vector})` to any dimensionality,
and allows any iterators of iterators. Method `stack(f, x)` generalises `mapreduce(f, hcat, x)` and
is efficient. ([#43334])

Library changes
---------------
Expand Down
230 changes: 230 additions & 0 deletions base/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2605,6 +2605,236 @@ end
Ai
end

"""
stack(iter; [dims])

Combine a collection of arrays (or other iterable objects) of equal size
into one larger array, by arranging them along one or more new dimensions.

By default the axes of the elements are placed first,
giving `size(result) = (size(first(iter))..., size(iter)...)`.
This has the same order of elements as [`Iterators.flatten`](@ref)`(iter)`.

With keyword `dims::Integer`, instead the `i`th element of `iter` becomes the slice
[`selectdim`](@ref)`(result, dims, i)`, so that `size(result, dims) == length(iter)`.
In this case `stack` reverses the action of [`eachslice`](@ref) with the same `dims`.

The various [`cat`](@ref) functions also combine arrays. However, these all
extend the arrays' existing (possibly trivial) dimensions, rather than placing
the arrays along new dimensions.
They also accept arrays as separate arguments, rather than a single collection.

!!! compat "Julia 1.9"
This function requires at least Julia 1.9.

# Examples
```jldoctest
julia> vecs = (1:2, [30, 40], Float32[500, 600]);

julia> mat = stack(vecs)
2×3 Matrix{Float32}:
1.0 30.0 500.0
2.0 40.0 600.0

julia> mat == hcat(vecs...) == reduce(hcat, collect(vecs))
true

julia> vec(mat) == vcat(vecs...) == reduce(vcat, collect(vecs))
true

julia> stack(zip(1:4, 10:99)) # accepts any iterators of iterators
2×4 Matrix{Int64}:
1 2 3 4
10 11 12 13

julia> vec(ans) == collect(Iterators.flatten(zip(1:4, 10:99)))
true

julia> stack(vecs; dims=1) # unlike any cat function, 1st axis of vecs[1] is 2nd axis of result
3×2 Matrix{Float32}:
1.0 2.0
30.0 40.0
500.0 600.0

julia> x = rand(3,4);

julia> x == stack(eachcol(x)) == stack(eachrow(x), dims=1) # inverse of eachslice
true
```

Higher-dimensional examples:

```jldoctest
julia> A = rand(5, 7, 11);

julia> E = eachslice(A, dims=2); # a vector of matrices

julia> (element = size(first(E)), container = size(E))
(element = (5, 11), container = (7,))

julia> stack(E) |> size
(5, 11, 7)

julia> stack(E) == stack(E; dims=3) == cat(E...; dims=3)
true

julia> A == stack(E; dims=2)
true

julia> M = (fill(10i+j, 2, 3) for i in 1:5, j in 1:7);

julia> (element = size(first(M)), container = size(M))
(element = (2, 3), container = (5, 7))

julia> stack(M) |> size # keeps all dimensions
(2, 3, 5, 7)

julia> stack(M; dims=1) |> size # vec(container) along dims=1
(35, 2, 3)

julia> hvcat(5, M...) |> size # hvcat puts matrices next to each other
(14, 15)
```
"""
stack(iter; dims=:) = _stack(dims, iter)

"""
stack(f, args...; [dims])

Apply a function to each element of a collection, and `stack` the result.
Or to several collections, [`zip`](@ref)ped together.

The function should return arrays (or tuples, or other iterators) all of the same size.
These become slices of the result, each separated along `dims` (if given) or by default
along the last dimensions.

See also [`mapslices`](@ref), [`eachcol`](@ref).

# Examples
```jldoctest
julia> stack(c -> (c, c-32), "julia")
2×5 Matrix{Char}:
'j' 'u' 'l' 'i' 'a'
'J' 'U' 'L' 'I' 'A'

julia> stack(eachrow([1 2 3; 4 5 6]), (10, 100); dims=1) do row, n
vcat(row, row .* n, row ./ n)
end
2×9 Matrix{Float64}:
1.0 2.0 3.0 10.0 20.0 30.0 0.1 0.2 0.3
4.0 5.0 6.0 400.0 500.0 600.0 0.04 0.05 0.06
```
"""
stack(f, iter; dims=:) = _stack(dims, f(x) for x in iter)
stack(f, xs, yzs...; dims=:) = _stack(dims, f(xy...) for xy in zip(xs, yzs...))

_stack(dims::Union{Integer, Colon}, iter) = _stack(dims, IteratorSize(iter), iter)

_stack(dims, ::IteratorSize, iter) = _stack(dims, collect(iter))

function _stack(dims, ::Union{HasShape, HasLength}, iter)
S = @default_eltype iter
T = S != Union{} ? eltype(S) : Any # Union{} occurs for e.g. stack(1,2), postpone the error
if isconcretetype(T)
_typed_stack(dims, T, S, iter)
else # Need to look inside, but shouldn't run an expensive iterator twice:
array = iter isa Union{Tuple, AbstractArray} ? iter : collect(iter)
isempty(array) && return _empty_stack(dims, T, S, iter)
T2 = mapreduce(eltype, promote_type, array)
_typed_stack(dims, T2, eltype(array), array)
end
end

function _typed_stack(::Colon, ::Type{T}, ::Type{S}, A, Aax=_iterator_axes(A)) where {T, S}
xit = iterate(A)
nothing === xit && return _empty_stack(:, T, S, A)
x1, _ = xit
ax1 = _iterator_axes(x1)
B = similar(_ensure_array(x1), T, ax1..., Aax...)
off = firstindex(B)
len = length(x1)
while xit !== nothing
x, state = xit
_stack_size_check(x, ax1)
copyto!(B, off, x)
off += len
xit = iterate(A, state)
end
B
end

_iterator_axes(x) = _iterator_axes(x, IteratorSize(x))
_iterator_axes(x, ::HasLength) = (OneTo(length(x)),)
_iterator_axes(x, ::IteratorSize) = axes(x)

# For some dims values, stack(A; dims) == stack(vec(A)), and the : path will be faster
_typed_stack(dims::Integer, ::Type{T}, ::Type{S}, A) where {T,S} =
_typed_stack(dims, T, S, IteratorSize(S), A)
_typed_stack(dims::Integer, ::Type{T}, ::Type{S}, ::HasLength, A) where {T,S} =
_typed_stack(dims, T, S, HasShape{1}(), A)
function _typed_stack(dims::Integer, ::Type{T}, ::Type{S}, ::HasShape{N}, A) where {T,S,N}
if dims == N+1
_typed_stack(:, T, S, A, (_vec_axis(A),))
else
_dim_stack(dims, T, S, A)
end
end
_typed_stack(dims::Integer, ::Type{T}, ::Type{S}, ::IteratorSize, A) where {T,S} =
_dim_stack(dims, T, S, A)

_vec_axis(A, ax=_iterator_axes(A)) = length(ax) == 1 ? only(ax) : OneTo(prod(length, ax; init=1))

@constprop :aggressive function _dim_stack(dims::Integer, ::Type{T}, ::Type{S}, A) where {T,S}
xit = Iterators.peel(A)
nothing === xit && return _empty_stack(dims, T, S, A)
x1, xrest = xit
ax1 = _iterator_axes(x1)
N1 = length(ax1)+1
dims in 1:N1 || throw(ArgumentError(LazyString("cannot stack slices ndims(x) = ", N1-1, " along dims = ", dims)))

newaxis = _vec_axis(A)
outax = ntuple(d -> d==dims ? newaxis : ax1[d - (d>dims)], N1)
B = similar(_ensure_array(x1), T, outax...)

if dims == 1
_dim_stack!(Val(1), B, x1, xrest)
elseif dims == 2
_dim_stack!(Val(2), B, x1, xrest)
else
_dim_stack!(Val(dims), B, x1, xrest)
end
B
end

function _dim_stack!(::Val{dims}, B::AbstractArray, x1, xrest) where {dims}
before = ntuple(d -> Colon(), dims - 1)
after = ntuple(d -> Colon(), ndims(B) - dims)

i = firstindex(B, dims)
copyto!(view(B, before..., i, after...), x1)

for x in xrest
_stack_size_check(x, _iterator_axes(x1))
i += 1
@inbounds copyto!(view(B, before..., i, after...), x)
end
end

@inline function _stack_size_check(x, ax1::Tuple)
if _iterator_axes(x) != ax1
uax1 = map(UnitRange, ax1)
uaxN = map(UnitRange, axes(x))
throw(DimensionMismatch(
LazyString("stack expects uniform slices, got axes(x) == ", uaxN, " while first had ", uax1)))
end
mcabbott marked this conversation as resolved.
Show resolved Hide resolved
end

_ensure_array(x::AbstractArray) = x
_ensure_array(x) = 1:0 # passed to similar, makes stack's output an Array

_empty_stack(_...) = throw(ArgumentError("`stack` on an empty collection is not allowed"))


## Reductions and accumulates ##

function isequal(A::AbstractArray, B::AbstractArray)
Expand Down
1 change: 1 addition & 0 deletions base/exports.jl
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,7 @@ export
sortperm!,
sortslices,
dropdims,
stack,
step,
stride,
strides,
Expand Down
16 changes: 15 additions & 1 deletion base/iterators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1199,7 +1199,7 @@ See also [`Iterators.flatten`](@ref), [`Iterators.map`](@ref).

# Examples
```jldoctest
julia> Iterators.flatmap(n->-n:2:n, 1:3) |> collect
julia> Iterators.flatmap(n -> -n:2:n, 1:3) |> collect
9-element Vector{Int64}:
-1
1
Expand All @@ -1210,6 +1210,20 @@ julia> Iterators.flatmap(n->-n:2:n, 1:3) |> collect
-1
1
3

julia> stack(n -> -n:2:n, 1:3)
ERROR: DimensionMismatch: stack expects uniform slices, got axes(x) == (1:3,) while first had (1:2,)
[...]

julia> Iterators.flatmap(n -> (-n, 10n), 1:2) |> collect
4-element Vector{Int64}:
-1
10
-2
20

julia> ans == vec(stack(n -> (-n, 10n), 1:2))
true
```
"""
flatmap(f, c...) = flatten(map(f, c...))
Expand Down
1 change: 1 addition & 0 deletions doc/src/base/arrays.md
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ Base.vcat
Base.hcat
Base.hvcat
Base.hvncat
Base.stack
Base.vect
Base.circshift
Base.circshift!
Expand Down
Loading