Skip to content

Commit

Permalink
Merge pull request #44061 from BSnelling/bes/collect_broadcasted_2
Browse files Browse the repository at this point in the history
Preserve shape when collecting broadcasted objects
  • Loading branch information
vtjnash committed May 27, 2022
2 parents 7e54f9a + 70fc3cd commit 938da26
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 2 deletions.
17 changes: 15 additions & 2 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ Base.IndexStyle(::Type{<:Broadcasted{<:Any}}) = IndexCartesian()

Base.LinearIndices(bc::Broadcasted{<:Any,<:Tuple{Any}}) = LinearIndices(axes(bc))::LinearIndices{1}

Base.ndims(::Broadcasted{<:Any,<:NTuple{N,Any}}) where {N} = N
Base.ndims(bc::Broadcasted) = ndims(typeof(bc))
Base.ndims(::Type{<:Broadcasted{<:Any,<:NTuple{N,Any}}}) where {N} = N

Base.size(bc::Broadcasted) = map(length, axes(bc))
Expand All @@ -261,7 +261,20 @@ Base.@propagate_inbounds function Base.iterate(bc::Broadcasted, s)
return (bc[i], (s[1], newstate))
end

Base.IteratorSize(::Type{<:Broadcasted{<:Any,<:NTuple{N,Base.OneTo}}}) where {N} = Base.HasShape{N}()
Base.IteratorSize(::Type{T}) where {T<:Broadcasted} = Base.HasShape{ndims(T)}()
Base.ndims(BC::Type{<:Broadcasted{<:Any,Nothing}}) = _maxndims(fieldtype(BC, 2))
Base.ndims(::Type{<:Broadcasted{<:AbstractArrayStyle{N},Nothing}}) where {N<:Integer} = N

_maxndims(T::Type{<:Tuple}) = reduce(max, (ntuple(n -> _ndims(fieldtype(T, n)), Base._counttuple(T))))
_maxndims(::Type{<:Tuple{T}}) where {T} = ndims(T)
_maxndims(::Type{<:Tuple{T}}) where {T<:Tuple} = _ndims(T)
function _maxndims(::Type{<:Tuple{T, S}}) where {T, S}
return T<:Tuple || S<:Tuple ? max(_ndims(T), _ndims(S)) : max(ndims(T), ndims(S))
end

_ndims(x) = ndims(x)
_ndims(::Type{<:Tuple}) = 1

Base.IteratorEltype(::Type{<:Broadcasted}) = Base.EltypeUnknown()

## Instantiation fills in the "missing" fields in Broadcasted.
Expand Down
33 changes: 33 additions & 0 deletions test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -855,6 +855,39 @@ let
@test ndims(copy(bc)) == ndims([v for v in bc]) == ndims(collect(bc)) == ndims(bc)
end

# issue 43847: collect preserves shape of broadcasted
let
bc = Broadcast.broadcasted(*, [1 2; 3 4], 2)
@test collect(Iterators.product(bc, bc)) == collect(Iterators.product(copy(bc), copy(bc)))

a1 = AD1(rand(2,3))
bc1 = Broadcast.broadcasted(*, a1, 2)
@test collect(Iterators.product(bc1, bc1)) == collect(Iterators.product(copy(bc1), copy(bc1)))

# using ndims of second arg
bc2 = Broadcast.broadcasted(*, 2, a1)
@test collect(Iterators.product(bc2, bc2)) == collect(Iterators.product(copy(bc2), copy(bc2)))

# >2 args
bc3 = Broadcast.broadcasted(*, a1, 3, a1)
@test collect(Iterators.product(bc3, bc3)) == collect(Iterators.product(copy(bc3), copy(bc3)))

# including a tuple and custom array type
bc4 = Broadcast.broadcasted(*, (1,2,3), AD1(rand(3)))
@test collect(Iterators.product(bc4, bc4)) == collect(Iterators.product(copy(bc4), copy(bc4)))

# testing ArrayConflict
@test Broadcast.broadcasted(+, AD1(rand(3)), AD2(rand(3))) isa Broadcast.Broadcasted{Broadcast.ArrayConflict}
@test Broadcast.broadcasted(+, AD1(rand(3)), AD2(rand(3))) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{Any}}

@test @inferred(Base.IteratorSize(Broadcast.broadcasted((1,2,3),a1,zeros(3,3,3)))) === Base.HasShape{3}()

# inference on nested
bc = Base.broadcasted(+, AD1(randn(3)), AD1(randn(3)))
bc_nest = Base.broadcasted(+, bc , bc)
@test @inferred(Base.IteratorSize(bc_nest)) === Base.HasShape{1}()
end

# issue #31295
let a = rand(5), b = rand(5), c = copy(a)
view(identity(a), 1:3) .+= view(b, 1:3)
Expand Down

0 comments on commit 938da26

Please sign in to comment.