diff --git a/base/broadcast.jl b/base/broadcast.jl index d3567818b10e4..65ee317c5632f 100644 --- a/base/broadcast.jl +++ b/base/broadcast.jl @@ -134,17 +134,17 @@ BroadcastStyle(a::AbstractArrayStyle, ::Style{Tuple}) = a BroadcastStyle(::A, ::A) where A<:ArrayStyle = A() BroadcastStyle(::ArrayStyle, ::ArrayStyle) = Unknown() BroadcastStyle(::A, ::A) where A<:AbstractArrayStyle = A() -Base.@pure function BroadcastStyle(a::A, b::B) where {A<:AbstractArrayStyle{M},B<:AbstractArrayStyle{N}} where {M,N} - if Base.typename(A).wrapper == Base.typename(B).wrapper - return A(_max(Val(M),Val(N))) +function BroadcastStyle(a::A, b::B) where {A<:AbstractArrayStyle{M},B<:AbstractArrayStyle{N}} where {M,N} + if Base.typename(A) === Base.typename(B) + return A(Val(max(M, N))) end - Unknown() + return Unknown() end # Any specific array type beats DefaultArrayStyle BroadcastStyle(a::AbstractArrayStyle{Any}, ::DefaultArrayStyle) = a BroadcastStyle(a::AbstractArrayStyle{N}, ::DefaultArrayStyle{N}) where N = a BroadcastStyle(a::AbstractArrayStyle{M}, ::DefaultArrayStyle{N}) where {M,N} = - typeof(a)(_max(Val(M),Val(N))) + typeof(a)(Val(max(M, N))) ### Lazy-wrapper for broadcasting @@ -201,9 +201,6 @@ Base.similar(bc::Broadcasted{ArrayConflict}, ::Type{Bool}) = similar(BitArray, axes(bc)) ## Computing the result's axes. Most types probably won't need to specialize this. -broadcast_axes() = () -broadcast_axes(A::Tuple) = (OneTo(length(A)),) -@inline broadcast_axes(A) = axes(A) """ Base.broadcast_axes(A) @@ -211,12 +208,11 @@ Compute the axes for `A`. This should only be specialized for objects that do not define [`axes`](@ref) but want to participate in broadcasting. """ -broadcast_axes +@inline broadcast_axes(A) = axes(A) @inline Base.axes(bc::Broadcasted) = _axes(bc, bc.axes) _axes(::Broadcasted, axes::Tuple) = axes @inline _axes(bc::Broadcasted, ::Nothing) = combine_axes(bc.args...) -_axes(bc::Broadcasted{Style{Tuple}}, ::Nothing) = (Base.OneTo(length(longest_tuple(nothing, bc.args))),) _axes(bc::Broadcasted{<:AbstractArrayStyle{0}}, ::Nothing) = () BroadcastStyle(::Type{<:Broadcasted{Style}}) where {Style} = Style() @@ -394,16 +390,6 @@ end ## Broadcasting utilities ## ## logic for deciding the BroadcastStyle -# Dimensionality: computing max(M,N) in the type domain so we preserve inferrability -_max(V1::Val{Any}, V2::Val{Any}) = Val(Any) -_max(V1::Val{Any}, V2::Val{N}) where N = Val(Any) -_max(V1::Val{N}, V2::Val{Any}) where N = Val(Any) -_max(V1::Val, V2::Val) = __max(longest(ntuple(identity, V1), ntuple(identity, V2))) -__max(::NTuple{N,Bool}) where N = Val(N) -longest(t1::Tuple, t2::Tuple) = (true, longest(Base.tail(t1), Base.tail(t2))...) -longest(::Tuple{}, t2::Tuple) = (true, longest((), Base.tail(t2))...) -longest(t1::Tuple, ::Tuple{}) = (true, longest(Base.tail(t1), ())...) -longest(::Tuple{}, ::Tuple{}) = () # combine_styles operates on values (arbitrarily many) combine_styles() = DefaultArrayStyle{0}() @@ -961,28 +947,12 @@ end ## Tuple methods -@inline copy(bc::Broadcasted{Style{Tuple}}) = - tuplebroadcast(longest_tuple(nothing, bc.args), bc) -@inline tuplebroadcast(::NTuple{N,Any}, bc) where {N} = ntuple(k -> @inbounds(_broadcast_getindex(bc, k)), Val(N)) -# This is a little tricky: find the longest tuple (first arg) within the list of arguments (second arg) -# Start with nothing as a placeholder and go until we find the first tuple in the argument list -longest_tuple(::Nothing, t::Tuple{Tuple,Vararg{Any}}) = longest_tuple(t[1], tail(t)) -# Or recurse through nested broadcast expressions -longest_tuple(::Nothing, t::Tuple{Broadcasted,Vararg{Any}}) = longest_tuple(longest_tuple(nothing, t[1].args), tail(t)) -longest_tuple(::Nothing, t::Tuple) = longest_tuple(nothing, tail(t)) -# And then compare it against all other tuples we find in the argument list or nested broadcasts -longest_tuple(l::Tuple, t::Tuple{Tuple,Vararg{Any}}) = longest_tuple(_longest_tuple(l, t[1]), tail(t)) -longest_tuple(l::Tuple, t::Tuple) = longest_tuple(l, tail(t)) -longest_tuple(l::Tuple, ::Tuple{}) = l -longest_tuple(l::Tuple, t::Tuple{Broadcasted}) = longest_tuple(l, t[1].args) -longest_tuple(l::Tuple, t::Tuple{Broadcasted,Vararg{Any}}) = longest_tuple(longest_tuple(l, t[1].args), tail(t)) -# Support only 1-tuples and N-tuples where there are no conflicts in N -_longest_tuple(A::Tuple{Any}, B::Tuple{Any}) = A -_longest_tuple(A::Tuple{Any}, B::NTuple{N,Any}) where N = B -_longest_tuple(A::NTuple{N,Any}, B::Tuple{Any}) where N = A -_longest_tuple(A::NTuple{N,Any}, B::NTuple{N,Any}) where N = A -@noinline _longest_tuple(A, B) = - throw(DimensionMismatch("tuples $A and $B could not be broadcast to a common size")) +@inline function copy(bc::Broadcasted{Style{Tuple}}) + axes = broadcast_axes(bc) + length(axes) == 1 || throw(DimensionMismatch("tuple only supports one dimension")) + N = Val(length(axes[1])) + return ntuple(k -> @inbounds(_broadcast_getindex(bc, k)), N) +end ## scalar-range broadcast operations ## # DefaultArrayStyle and \ are not available at the time of range.jl