diff --git a/src/collect.jl b/src/collect.jl index bceb1a68..368c4389 100644 --- a/src/collect.jl +++ b/src/collect.jl @@ -30,6 +30,12 @@ reshapestructarray(v::AbstractArray, d) = reshape(v, d) reshapestructarray(v::StructArray{T}, d) where {T} = StructArray{T}(map(x -> reshapestructarray(x, d), fieldarrays(v))) +function collect_empty_structarray(itr::T; initializer = default_initializer) where {T} + S = Core.Compiler.return_type(first, Tuple{T}) + res = initializer(S, (0,)) + _reshape(res, itr) +end + """ `collect_structarray(itr, fr=iterate(itr); initializer = default_initializer)` @@ -39,28 +45,29 @@ and size `d`. By default `initializer` returns a `StructArray` of `Array` but cu may be used. `fr` represents the moment in the iteration of `itr` from which to start collecting. """ collect_structarray(itr; initializer = default_initializer) = - collect_structarray(itr, iterate(itr); initializer = initializer) - -collect_structarray(itr, fr; initializer = default_initializer) = - collect_structarray(itr, fr, Base.IteratorSize(itr); initializer = initializer) + _collect_structarray(itr, Base.IteratorSize(itr); initializer = initializer) -collect_structarray(itr, ::Nothing; initializer = default_initializer) = - collect_empty_structarray(itr; initializer = initializer) - -function collect_empty_structarray(itr::T; initializer = default_initializer) where {T} - S = Core.Compiler.return_type(first, Tuple{T}) - res = initializer(S, (0,)) - _reshape(res, itr) +function _collect_structarray(itr, sz::Union{Base.HasShape, Base.HasLength}; + initializer = default_initializer) + len = length(itr) + elem = iterate(itr) + elem === nothing && return collect_empty_structarray(itr, initializer = initializer) + el, st = elem + S = typeof(el) + dest = initializer(S, (len,)) + dest[1] = el + v = collect_to_structarray!(dest, itr, 2, st) + _reshape(v, itr, sz) end -function collect_structarray(itr, elem, sz::Union{Base.HasShape, Base.HasLength}; - initializer = default_initializer) - el, i = elem +function _collect_structarray(itr, ::Base.SizeUnknown; initializer = default_initializer) + elem = iterate(itr) + elem === nothing && return collect_empty_structarray(itr, initializer = initializer) + el, st = elem S = typeof(el) - dest = initializer(S, (length(itr),)) + dest = initializer(S, (1,)) dest[1] = el - v = collect_to_structarray!(dest, itr, 2, i) - _reshape(v, itr, sz) + grow_to_structarray!(dest, itr, iterate(itr, st)) end function collect_to_structarray!(dest::AbstractArray, itr, offs, st) @@ -83,13 +90,6 @@ function collect_to_structarray!(dest::AbstractArray, itr, offs, st) return dest end -function collect_structarray(itr, elem, ::Base.SizeUnknown; initializer = default_initializer) - el, st = elem - dest = initializer(typeof(el), (1,)) - dest[1] = el - grow_to_structarray!(dest, itr, iterate(itr, st)) -end - function grow_to_structarray!(dest::AbstractArray, itr, elem = iterate(itr)) # collect to dest array, checking the type of each result. if a result does not # match, widen the result type and re-dispatch. @@ -152,16 +152,10 @@ function _append!!(dest::AbstractVector, itr, ::Union{Base.HasShape, Base.HasLen fr === nothing && return dest el, st = fr i = lastindex(dest) + 1 - if iscompatible(el, dest) - resize!(dest, length(dest) + n) - @inbounds dest[i] = el - return collect_to_structarray!(dest, itr, i + 1, st) - else - new = widenstructarray(dest, i, el) - resize!(new, length(dest) + n) - @inbounds new[i] = el - return collect_to_structarray!(new, itr, i + 1, st) - end + new = iscompatible(el, dest) ? dest : widenstructarray(dest, i, el) + resize!(new, length(dest) + n) + @inbounds new[i] = el + return collect_to_structarray!(new, itr, i + 1, st) end _append!!(dest::AbstractVector, itr, ::Base.SizeUnknown) = diff --git a/test/runtests.jl b/test/runtests.jl index 03d76c45..8a50c566 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -658,8 +658,7 @@ end ("SizeUnknown", () -> (x for x in itr if isodd(x.a))), # Broken due to https://github.com/JuliaArrays/StructArrays.jl/issues/100: # ("empty", (x for x in itr if false)), - # Broken due to https://github.com/JuliaArrays/StructArrays.jl/issues/99: - # ("stateful", () -> Iterators.Stateful(itr)), + ("stateful", () -> Iterators.Stateful(itr)), ] @testset "$destlabel $itrlabel" for (destlabel, dest) in dest_examples, (itrlabel, makeitr) in itr_examples