Skip to content

Commit

Permalink
broadcast: remove unnecessary code duplication
Browse files Browse the repository at this point in the history
The problem it existed to solve is now handled directly by constant propagation.
  • Loading branch information
vtjnash committed Nov 1, 2018
1 parent 817f6fc commit 9e98386
Showing 1 changed file with 12 additions and 42 deletions.
54 changes: 12 additions & 42 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -134,17 +134,17 @@ BroadcastStyle(a::AbstractArrayStyle, ::Style{Tuple}) = a
BroadcastStyle(::A, ::A) where A<:ArrayStyle = A()
BroadcastStyle(::ArrayStyle, ::ArrayStyle) = Unknown()
BroadcastStyle(::A, ::A) where A<:AbstractArrayStyle = A()
Base.@pure function BroadcastStyle(a::A, b::B) where {A<:AbstractArrayStyle{M},B<:AbstractArrayStyle{N}} where {M,N}
if Base.typename(A).wrapper == Base.typename(B).wrapper
return A(_max(Val(M),Val(N)))
function BroadcastStyle(a::A, b::B) where {A<:AbstractArrayStyle{M},B<:AbstractArrayStyle{N}} where {M,N}
if Base.typename(A) === Base.typename(B)
return A(Val(max(M, N)))
end
Unknown()
return Unknown()
end
# Any specific array type beats DefaultArrayStyle
BroadcastStyle(a::AbstractArrayStyle{Any}, ::DefaultArrayStyle) = a
BroadcastStyle(a::AbstractArrayStyle{N}, ::DefaultArrayStyle{N}) where N = a
BroadcastStyle(a::AbstractArrayStyle{M}, ::DefaultArrayStyle{N}) where {M,N} =
typeof(a)(_max(Val(M),Val(N)))
typeof(a)(Val(max(M, N)))

### Lazy-wrapper for broadcasting

Expand Down Expand Up @@ -201,22 +201,18 @@ Base.similar(bc::Broadcasted{ArrayConflict}, ::Type{Bool}) =
similar(BitArray, axes(bc))

## Computing the result's axes. Most types probably won't need to specialize this.
broadcast_axes() = ()
broadcast_axes(A::Tuple) = (OneTo(length(A)),)
@inline broadcast_axes(A) = axes(A)
"""
Base.broadcast_axes(A)
Compute the axes for `A`.
This should only be specialized for objects that do not define [`axes`](@ref) but want to participate in broadcasting.
"""
broadcast_axes
@inline broadcast_axes(A) = axes(A)

@inline Base.axes(bc::Broadcasted) = _axes(bc, bc.axes)
_axes(::Broadcasted, axes::Tuple) = axes
@inline _axes(bc::Broadcasted, ::Nothing) = combine_axes(bc.args...)
_axes(bc::Broadcasted{Style{Tuple}}, ::Nothing) = (Base.OneTo(length(longest_tuple(nothing, bc.args))),)
_axes(bc::Broadcasted{<:AbstractArrayStyle{0}}, ::Nothing) = ()

BroadcastStyle(::Type{<:Broadcasted{Style}}) where {Style} = Style()
Expand Down Expand Up @@ -394,16 +390,6 @@ end
## Broadcasting utilities ##

## logic for deciding the BroadcastStyle
# Dimensionality: computing max(M,N) in the type domain so we preserve inferrability
_max(V1::Val{Any}, V2::Val{Any}) = Val(Any)
_max(V1::Val{Any}, V2::Val{N}) where N = Val(Any)
_max(V1::Val{N}, V2::Val{Any}) where N = Val(Any)
_max(V1::Val, V2::Val) = __max(longest(ntuple(identity, V1), ntuple(identity, V2)))
__max(::NTuple{N,Bool}) where N = Val(N)
longest(t1::Tuple, t2::Tuple) = (true, longest(Base.tail(t1), Base.tail(t2))...)
longest(::Tuple{}, t2::Tuple) = (true, longest((), Base.tail(t2))...)
longest(t1::Tuple, ::Tuple{}) = (true, longest(Base.tail(t1), ())...)
longest(::Tuple{}, ::Tuple{}) = ()

# combine_styles operates on values (arbitrarily many)
combine_styles() = DefaultArrayStyle{0}()
Expand Down Expand Up @@ -961,28 +947,12 @@ end

## Tuple methods

@inline copy(bc::Broadcasted{Style{Tuple}}) =
tuplebroadcast(longest_tuple(nothing, bc.args), bc)
@inline tuplebroadcast(::NTuple{N,Any}, bc) where {N} = ntuple(k -> @inbounds(_broadcast_getindex(bc, k)), Val(N))
# This is a little tricky: find the longest tuple (first arg) within the list of arguments (second arg)
# Start with nothing as a placeholder and go until we find the first tuple in the argument list
longest_tuple(::Nothing, t::Tuple{Tuple,Vararg{Any}}) = longest_tuple(t[1], tail(t))
# Or recurse through nested broadcast expressions
longest_tuple(::Nothing, t::Tuple{Broadcasted,Vararg{Any}}) = longest_tuple(longest_tuple(nothing, t[1].args), tail(t))
longest_tuple(::Nothing, t::Tuple) = longest_tuple(nothing, tail(t))
# And then compare it against all other tuples we find in the argument list or nested broadcasts
longest_tuple(l::Tuple, t::Tuple{Tuple,Vararg{Any}}) = longest_tuple(_longest_tuple(l, t[1]), tail(t))
longest_tuple(l::Tuple, t::Tuple) = longest_tuple(l, tail(t))
longest_tuple(l::Tuple, ::Tuple{}) = l
longest_tuple(l::Tuple, t::Tuple{Broadcasted}) = longest_tuple(l, t[1].args)
longest_tuple(l::Tuple, t::Tuple{Broadcasted,Vararg{Any}}) = longest_tuple(longest_tuple(l, t[1].args), tail(t))
# Support only 1-tuples and N-tuples where there are no conflicts in N
_longest_tuple(A::Tuple{Any}, B::Tuple{Any}) = A
_longest_tuple(A::Tuple{Any}, B::NTuple{N,Any}) where N = B
_longest_tuple(A::NTuple{N,Any}, B::Tuple{Any}) where N = A
_longest_tuple(A::NTuple{N,Any}, B::NTuple{N,Any}) where N = A
@noinline _longest_tuple(A, B) =
throw(DimensionMismatch("tuples $A and $B could not be broadcast to a common size"))
@inline function copy(bc::Broadcasted{Style{Tuple}})
axes = broadcast_axes(bc)
length(axes) == 1 || throw(DimensionMismatch("tuple only supports one dimension"))
N = Val(length(axes[1]))
return ntuple(k -> @inbounds(_broadcast_getindex(bc, k)), N)
end

## scalar-range broadcast operations ##
# DefaultArrayStyle and \ are not available at the time of range.jl
Expand Down

0 comments on commit 9e98386

Please sign in to comment.