Skip to content

Commit

Permalink
Customizable lazy fused broadcasting in pure Julia
Browse files Browse the repository at this point in the history
This patch represents the combined efforts of four individuals, over 60
commits, and an iterated design over (at least) three pull requests that
spanned nearly an entire year (closes #22063, #23692, #25377 by superceding
them).

This introduces a pure Julia data structure that represents a fused broadcast
expression.  For example, the expression `2 .* (x .+ 1)` lowers to:

```julia
julia> Meta.@lower 2 .* (x .+ 1)
:($(Expr(:thunk, CodeInfo(:(begin
      Core.SSAValue(0) = (Base.getproperty)(Base.Broadcast, :materialize)
      Core.SSAValue(1) = (Base.getproperty)(Base.Broadcast, :make)
      Core.SSAValue(2) = (Base.getproperty)(Base.Broadcast, :make)
      Core.SSAValue(3) = (Core.SSAValue(2))(+, x, 1)
      Core.SSAValue(4) = (Core.SSAValue(1))(*, 2, Core.SSAValue(3))
      Core.SSAValue(5) = (Core.SSAValue(0))(Core.SSAValue(4))
      return Core.SSAValue(5)
  end)))))
```

Or, slightly more readably as:

```julia
using .Broadcast: materialize, make
materialize(make(*, 2, make(+, x, 1)))
```

The `Broadcast.make` function serves two purposes. Its primary purpose is to
construct the `Broadcast.Broadcasted` objects that hold onto the function, the
tuple of arguments (potentially including nested `Broadcasted` arguments), and
sometimes a set of `axes` to include knowledge of the outer shape. The
secondary purpose, however, is to allow an "out" for objects that _don't_ want
to participate in fusion. For example, if `x` is a range in the above `2 .* (x
.+ 1)` expression, it needn't allocate an array and operate elementwise — it
can just compute and return a new range. Thus custom structures are able to
specialize `Broadcast.make(f, args...)` just as they'd specialize on `f`
normally to return an immediate result.

`Broadcast.materialize` is identity for everything _except_ `Broadcasted`
objects for which it allocates an appropriate result and computes the
broadcast. It does two things: it `initialize`s the outermost `Broadcasted`
object to compute its axes and then `copy`s it.

Similarly, an in-place fused broadcast like `y .= 2 .* (x .+ 1)` uses the exact
same expression tree to compute the right-hand side of the expression as above,
and then uses `materialize!(y, make(*, 2, make(+, x, 1)))` to `instantiate` the
`Broadcasted` expression tree and then `copyto!` it into the given destination.

All-together, this forms a complete API for custom types to extend and
customize the behavior of broadcast (fixes #22060). It uses the existing
`BroadcastStyle`s throughout to simplify dispatch on many arguments:

* Custom types can opt-out of broadcast fusion by specializing
  `Broadcast.make(f, args...)` or `Broadcast.make(::BroadcastStyle, f, args...)`.

* The `Broadcasted` object computes and stores the type of the combined
  `BroadcastStyle` of its arguments as its first type parameter, allowing for
  easy dispatch and specialization.

* Custom Broadcast storage is still allocated via `broadcast_similar`, however
  instead of passing just a function as a first argument, the entire
  `Broadcasted` object is passed as a final argument. This potentially allows
  for much more runtime specialization dependent upon the exact expression
  given.

* Custom broadcast implmentations for a `CustomStyle` are defined by
  specializing `copy(bc::Broadcasted{CustomStyle})` or
  `copyto!(dest::AbstractArray, bc::Broadcasted{CustomStyle})`.

* Fallback broadcast specializations for a given output object of type `Dest`
  (for the `DefaultArrayStyle` or another such style that hasn't implemented
  assignments into such an object) are defined by specializing
  `copyto(dest::Dest, bc::Broadcasted{Nothing})`.

As it fully supports range broadcasting, this now deprecates `(1:5) + 2` to
`.+`, just as had been done for all `AbstractArray`s in general.

As a first-mover proof of concept, LinearAlgebra uses this new system to
improve broadcasting over structured arrays. Before, broadcasting over a
structured matrix would result in a sparse array. Now, broadcasting over a
structured matrix will _either_ return an appropriately structured matrix _or_
a dense array. This does incur a type instability (in the form of a
discriminated union) in some situations, but thanks to type-based introspection
of the `Broadcasted` wrapper commonly used functions can be special cased to be
type stable.  For example:

```julia
julia> f(d) = round.(Int, d)
f (generic function with 1 method)

julia> @inferred f(Diagonal(rand(3)))
3×3 Diagonal{Int64,Array{Int64,1}}:
 0  ⋅  ⋅
 ⋅  0  ⋅
 ⋅  ⋅  1

julia> @inferred Diagonal(rand(3)) .* 3
ERROR: return type Diagonal{Float64,Array{Float64,1}} does not match inferred return type Union{Array{Float64,2}, Diagonal{Float64,Array{Float64,1}}}
Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] top-level scope

julia> @inferred Diagonal(1:4) .+ Bidiagonal(rand(4), rand(3), 'U') .* Tridiagonal(1:3, 1:4, 1:3)
4×4 Tridiagonal{Float64,Array{Float64,1}}:
 1.30771  0.838589   ⋅          ⋅
 0.0      3.89109   0.0459757   ⋅
  ⋅       0.0       4.48033    2.51508
  ⋅        ⋅        0.0        6.23739
```

In addition to the issues referenced above, it fixes:

* Fixes #19313, #22053, #23445, and #24586: Literals are no longer treated
  specially in a fused broadcast; they're just arguments in a `Broadcasted`
  object like everything else.

* Fixes #21094: Since broadcasting is now represented by a pure Julia
  datastructure it can be created within `@generated` functions and serialized.

* Fixes #26097: The fallback destination-array specialization method of
  `copyto!` is specifically implemented as `Broadcasted{Nothing}` and will not
  be confused by `nothing` arguments.

* Fixes the broadcast-specific element of #25499: The default base broadcast
  implementation no longer depends upon `Base._return_type` to allocate its
  array (except in the empty or concretely-type cases). Note that the sparse
  implementation (#19595) is still dependent upon inference and is _not_ fixed.

* Fixes #25340: Functions are treated like normal values just like arguments
  and only evaluated once.

* Fixes #22255, and is performant with 12+ fused broadcasts. Okay, that one was
  fixed on master already, but this fixes it now, too.

* Fixes #25521.

* The performance of this patch has been thoroughly tested through its
  iterative development process in #25377. There remain [two classes of
  performance regressions](#25377) that Nanosoldier flagged.

* #25691: Propagation of constant literals sill lose their constant-ness upon
  going through the broadcast machinery. I believe quite a large number of
  functions would need to be marked as `@pure` to support this -- including
  functions that are intended to be specialized.

(For bookkeeping, this is the squashed version of the [teh-jn/lazydotfuse](JuliaLang/julia#25377)
branch as of a1d4e7ec9756ada74fb48f2c514615b9d981cf5c. Squashed and separated
out to make it easier to review and commit)

Co-authored-by: Tim Holy <tim.holy@gmail.com>
Co-authored-by: Jameson Nash <vtjnash@gmail.com>
Co-authored-by: Andrew Keller <ajkeller34@users.noreply.github.com>
  • Loading branch information
4 people committed Apr 23, 2018
1 parent 05a3c41 commit 0c4ec6e
Show file tree
Hide file tree
Showing 2 changed files with 172 additions and 166 deletions.
258 changes: 152 additions & 106 deletions src/higherorderfns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,16 @@ module HigherOrderFns

# This module provides higher order functions specialized for sparse arrays,
# particularly map[!]/broadcast[!] for SparseVectors and SparseMatrixCSCs at present.
import Base: map, map!, broadcast, broadcast!
import Base: map, map!, broadcast, copy, copyto!

using Base: front, tail, to_shape
using ..SparseArrays: SparseVector, SparseMatrixCSC, AbstractSparseVector,
AbstractSparseMatrix, AbstractSparseArray, indtype, nnz, nzrange
using Base.Broadcast: BroadcastStyle
using Base.Broadcast: BroadcastStyle, Broadcasted, flatten
using LinearAlgebra

# This module is organized as follows:
# (0) Define BroadcastStyle rules and convenience types for dispatch
# (1) Define a common interface to SparseVectors and SparseMatrixCSCs sufficient for
# map[!]/broadcast[!]'s purposes. The methods below are written against this interface.
# (2) Define entry points for map[!] (short children of _map_[not]zeropres!).
Expand All @@ -29,11 +30,79 @@ using LinearAlgebra
# (12) Define map[!] methods handling combinations of sparse and structured matrices.


# (0) BroadcastStyle rules and convenience types for dispatch

SparseVecOrMat = Union{SparseVector,SparseMatrixCSC}

# broadcast container type promotion for combinations of sparse arrays and other types
struct SparseVecStyle <: Broadcast.AbstractArrayStyle{1} end
struct SparseMatStyle <: Broadcast.AbstractArrayStyle{2} end
Broadcast.BroadcastStyle(::Type{<:SparseVector}) = SparseVecStyle()
Broadcast.BroadcastStyle(::Type{<:SparseMatrixCSC}) = SparseMatStyle()
const SPVM = Union{SparseVecStyle,SparseMatStyle}

# SparseVecStyle handles 0-1 dimensions, SparseMatStyle 0-2 dimensions.
# SparseVecStyle promotes to SparseMatStyle for 2 dimensions.
# Fall back to DefaultArrayStyle for higher dimensionality.
SparseVecStyle(::Val{0}) = SparseVecStyle()
SparseVecStyle(::Val{1}) = SparseVecStyle()
SparseVecStyle(::Val{2}) = SparseMatStyle()
SparseVecStyle(::Val{N}) where N = Broadcast.DefaultArrayStyle{N}()
SparseMatStyle(::Val{0}) = SparseMatStyle()
SparseMatStyle(::Val{1}) = SparseMatStyle()
SparseMatStyle(::Val{2}) = SparseMatStyle()
SparseMatStyle(::Val{N}) where N = Broadcast.DefaultArrayStyle{N}()

Broadcast.BroadcastStyle(::SparseMatStyle, ::SparseVecStyle) = SparseMatStyle()

# Tuples promote to dense
Broadcast.BroadcastStyle(::SparseVecStyle, ::Broadcast.Style{Tuple}) = Broadcast.DefaultArrayStyle{1}()
Broadcast.BroadcastStyle(::SparseMatStyle, ::Broadcast.Style{Tuple}) = Broadcast.DefaultArrayStyle{2}()

struct PromoteToSparse <: Broadcast.AbstractArrayStyle{2} end
PromoteToSparse(::Val{0}) = PromoteToSparse()
PromoteToSparse(::Val{1}) = PromoteToSparse()
PromoteToSparse(::Val{2}) = PromoteToSparse()
PromoteToSparse(::Val{N}) where N = Broadcast.DefaultArrayStyle{N}()

const StructuredMatrix = Union{Diagonal,Bidiagonal,Tridiagonal,SymTridiagonal}
Broadcast.BroadcastStyle(::Type{<:Adjoint{T,<:Union{SparseVector,SparseMatrixCSC}} where T}) = PromoteToSparse()
Broadcast.BroadcastStyle(::Type{<:Transpose{T,<:Union{SparseVector,SparseMatrixCSC}} where T}) = PromoteToSparse()

Broadcast.BroadcastStyle(s::SPVM, ::Broadcast.AbstractArrayStyle{0}) = s
Broadcast.BroadcastStyle(::SPVM, ::Broadcast.DefaultArrayStyle{1}) = PromoteToSparse()
Broadcast.BroadcastStyle(::SPVM, ::Broadcast.DefaultArrayStyle{2}) = PromoteToSparse()

Broadcast.BroadcastStyle(::SPVM, ::LinearAlgebra.StructuredMatrixStyle{<:StructuredMatrix}) = PromoteToSparse()
Broadcast.BroadcastStyle(::PromoteToSparse, ::LinearAlgebra.StructuredMatrixStyle{<:StructuredMatrix}) = PromoteToSparse()

Broadcast.BroadcastStyle(::PromoteToSparse, ::SPVM) = PromoteToSparse()
Broadcast.BroadcastStyle(::PromoteToSparse, ::Broadcast.Style{Tuple}) = Broadcast.DefaultArrayStyle{2}()

# FIXME: currently sparse broadcasts are only well-tested on known array types, while any AbstractArray
# could report itself as a DefaultArrayStyle().
# See https://github.com/JuliaLang/julia/pull/23939#pullrequestreview-72075382 for more details
is_supported_sparse_broadcast() = true
is_supported_sparse_broadcast(::AbstractArray, rest...) = false
is_supported_sparse_broadcast(::AbstractSparseArray, rest...) = is_supported_sparse_broadcast(rest...)
is_supported_sparse_broadcast(::StructuredMatrix, rest...) = is_supported_sparse_broadcast(rest...)
is_supported_sparse_broadcast(::Array, rest...) = is_supported_sparse_broadcast(rest...)
is_supported_sparse_broadcast(t::Union{Transpose, Adjoint}, rest...) = is_supported_sparse_broadcast(t.parent, rest...)
is_supported_sparse_broadcast(x, rest...) = axes(x) === () && is_supported_sparse_broadcast(rest...)
is_supported_sparse_broadcast(x::Ref, rest...) = is_supported_sparse_broadcast(rest...)

# Dispatch on broadcast operations by number of arguments
const Broadcasted0{Style<:Union{Nothing,BroadcastStyle},Axes,F} =
Broadcasted{Style,Axes,F,Tuple{}}
const SpBroadcasted1{Style<:SPVM,Axes,F,Args<:Tuple{SparseVecOrMat}} =
Broadcasted{Style,Axes,F,Args}
const SpBroadcasted2{Style<:SPVM,Axes,F,Args<:Tuple{SparseVecOrMat,SparseVecOrMat}} =
Broadcasted{Style,Axes,F,Args}

# (1) The definitions below provide a common interface to sparse vectors and matrices
# sufficient for the purposes of map[!]/broadcast[!]. This interface treats sparse vectors
# as n-by-one sparse matrices which, though technically incorrect, is how broacast[!] views
# sparse vectors in practice.
SparseVecOrMat = Union{SparseVector,SparseMatrixCSC}
@inline numrows(A::SparseVector) = A.n
@inline numrows(A::SparseMatrixCSC) = A.m
@inline numcols(A::SparseVector) = 1
Expand Down Expand Up @@ -85,18 +154,18 @@ function _noshapecheck_map(f::Tf, A::SparseVecOrMat, Bs::Vararg{SparseVecOrMat,N
fofzeros = f(_zeros_eltypes(A, Bs...)...)
fpreszeros = _iszero(fofzeros)
maxnnzC = fpreszeros ? min(length(A), _sumnnzs(A, Bs...)) : length(A)
entrytypeC = Base.Broadcast.combine_eltypes(f, A, Bs...)
entrytypeC = Base.Broadcast.combine_eltypes(f, (A, Bs...))
indextypeC = _promote_indtype(A, Bs...)
C = _allocres(size(A), indextypeC, entrytypeC, maxnnzC)
return fpreszeros ? _map_zeropres!(f, C, A, Bs...) :
_map_notzeropres!(f, fofzeros, C, A, Bs...)
end
# (3) broadcast[!] entry points
broadcast(f::Tf, A::SparseVector) where {Tf} = _noshapecheck_map(f, A)
broadcast(f::Tf, A::SparseMatrixCSC) where {Tf} = _noshapecheck_map(f, A)
copy(bc::SpBroadcasted1) = _noshapecheck_map(bc.f, bc.args[1])

@inline function broadcast!(f::Tf, C::SparseVecOrMat, ::Nothing) where Tf
@inline function copyto!(C::SparseVecOrMat, bc::Broadcasted0{Nothing})
isempty(C) && return _finishempty!(C)
f = bc.f
fofnoargs = f()
if _iszero(fofnoargs) # f() is zero, so empty C
trimstorage!(C, 0)
Expand All @@ -109,19 +178,12 @@ broadcast(f::Tf, A::SparseMatrixCSC) where {Tf} = _noshapecheck_map(f, A)
return C
end

# the following three similar defs are necessary for type stability in the mixed vector/matrix case
broadcast(f::Tf, A::SparseVector, Bs::Vararg{SparseVector,N}) where {Tf,N} =
_aresameshape(A, Bs...) ? _noshapecheck_map(f, A, Bs...) : _diffshape_broadcast(f, A, Bs...)
broadcast(f::Tf, A::SparseMatrixCSC, Bs::Vararg{SparseMatrixCSC,N}) where {Tf,N} =
_aresameshape(A, Bs...) ? _noshapecheck_map(f, A, Bs...) : _diffshape_broadcast(f, A, Bs...)
broadcast(f::Tf, A::SparseVecOrMat, Bs::Vararg{SparseVecOrMat,N}) where {Tf,N} =
_diffshape_broadcast(f, A, Bs...)
function _diffshape_broadcast(f::Tf, A::SparseVecOrMat, Bs::Vararg{SparseVecOrMat,N}) where {Tf,N}
fofzeros = f(_zeros_eltypes(A, Bs...)...)
fpreszeros = _iszero(fofzeros)
indextypeC = _promote_indtype(A, Bs...)
entrytypeC = Base.Broadcast.combine_eltypes(f, A, Bs...)
shapeC = to_shape(Base.Broadcast.combine_indices(A, Bs...))
entrytypeC = Base.Broadcast.combine_eltypes(f, (A, Bs...))
shapeC = to_shape(Base.Broadcast.combine_axes(A, Bs...))
maxnnzC = fpreszeros ? _checked_maxnnzbcres(shapeC, A, Bs...) : _densennz(shapeC)
C = _allocres(shapeC, indextypeC, entrytypeC, maxnnzC)
return fpreszeros ? _broadcast_zeropres!(f, C, A, Bs...) :
Expand All @@ -141,6 +203,10 @@ end
@inline _aresameshape(A, B) = size(A) == size(B)
@inline _aresameshape(A, B, Cs...) = _aresameshape(A, B) ? _aresameshape(B, Cs...) : false
@inline _checksameshape(As...) = _aresameshape(As...) || throw(DimensionMismatch("argument shapes must match"))
@inline _all_args_isa(t::Tuple{Any}, ::Type{T}) where T = isa(t[1], T)
@inline _all_args_isa(t::Tuple{Any,Vararg{Any}}, ::Type{T}) where T = isa(t[1], T) & _all_args_isa(tail(t), T)
@inline _all_args_isa(t::Tuple{Broadcasted}, ::Type{T}) where T = _all_args_isa(t[1].args, T)
@inline _all_args_isa(t::Tuple{Broadcasted,Vararg{Any}}, ::Type{T}) where T = _all_args_isa(t[1].args, T) & _all_args_isa(tail(t), T)
@inline _densennz(shape::NTuple{1}) = shape[1]
@inline _densennz(shape::NTuple{2}) = shape[1] * shape[2]
_maxnnzfrom(shape::NTuple{1}, A) = nnz(A) * div(shape[1], A.n)
Expand Down Expand Up @@ -887,37 +953,56 @@ end

# (10) broadcast over combinations of broadcast scalars and sparse vectors/matrices

# broadcast container type promotion for combinations of sparse arrays and other types
struct SparseVecStyle <: Broadcast.AbstractArrayStyle{1} end
struct SparseMatStyle <: Broadcast.AbstractArrayStyle{2} end
Broadcast.BroadcastStyle(::Type{<:SparseVector}) = SparseVecStyle()
Broadcast.BroadcastStyle(::Type{<:SparseMatrixCSC}) = SparseMatStyle()
const SPVM = Union{SparseVecStyle,SparseMatStyle}
# broadcast entry points for combinations of sparse arrays and other (scalar) types
@inline function copy(bc::Broadcasted{<:SPVM})
bcf = flatten(bc)
return _copy(bcf.f, bcf.args...)
end

# SparseVecStyle handles 0-1 dimensions, SparseMatStyle 0-2 dimensions.
# SparseVecStyle promotes to SparseMatStyle for 2 dimensions.
# Fall back to DefaultArrayStyle for higher dimensionality.
SparseVecStyle(::Val{0}) = SparseVecStyle()
SparseVecStyle(::Val{1}) = SparseVecStyle()
SparseVecStyle(::Val{2}) = SparseMatStyle()
SparseVecStyle(::Val{N}) where N = Broadcast.DefaultArrayStyle{N}()
SparseMatStyle(::Val{0}) = SparseMatStyle()
SparseMatStyle(::Val{1}) = SparseMatStyle()
SparseMatStyle(::Val{2}) = SparseMatStyle()
SparseMatStyle(::Val{N}) where N = Broadcast.DefaultArrayStyle{N}()
_copy(f, args::SparseVector...) = _shapecheckbc(f, args...)
_copy(f, args::SparseMatrixCSC...) = _shapecheckbc(f, args...)
_copy(f, args::SparseVecOrMat...) = _diffshape_broadcast(f, args...)
# Otherwise, we incorporate scalars into the function and re-dispatch
function _copy(f, args...)
parevalf, passedargstup = capturescalars(f, args)
return _copy(parevalf, passedargstup...)
end

Broadcast.BroadcastStyle(::SparseMatStyle, ::SparseVecStyle) = SparseMatStyle()
function _shapecheckbc(f, args...)
_aresameshape(args...) ? _noshapecheck_map(f, args...) : _diffshape_broadcast(f, args...)
end

# Tuples promote to dense
Broadcast.BroadcastStyle(::SparseVecStyle, ::Broadcast.Style{Tuple}) = Broadcast.DefaultArrayStyle{1}()
Broadcast.BroadcastStyle(::SparseMatStyle, ::Broadcast.Style{Tuple}) = Broadcast.DefaultArrayStyle{2}()

# broadcast entry points for combinations of sparse arrays and other (scalar) types
function broadcast(f, ::SPVM, ::Nothing, ::Nothing, mixedargs::Vararg{Any,N}) where N
parevalf, passedargstup = capturescalars(f, mixedargs)
return broadcast(parevalf, passedargstup...)
@inline function copyto!(dest::SparseVecOrMat, bc::Broadcasted{<:SPVM})
if bc.f === identity && bc isa SpBroadcasted1 && Base.axes(dest) == (A = bc.args[1]; Base.axes(A))
return copyto!(dest, A)
end
bcf = flatten(bc)
As = map(arg->Base.unalias(dest, arg), bcf.args)
return _copyto!(bcf.f, dest, As...)
end

@inline function _copyto!(f, dest, As::SparseVecOrMat...)
_aresameshape(dest, As...) && return _noshapecheck_map!(f, dest, As...)
Base.Broadcast.check_broadcast_axes(axes(dest), As...)
fofzeros = f(_zeros_eltypes(As...)...)
if _iszero(fofzeros)
return _broadcast_zeropres!(f, dest, As...)
else
return _broadcast_notzeropres!(f, fofzeros, dest, As...)
end
end

@inline function _copyto!(f, dest, args...)
# args contains nothing but SparseVecOrMat and scalars
# See below for capturescalars
parevalf, passedsrcargstup = capturescalars(f, args)
_copyto!(parevalf, dest, passedsrcargstup...)
end

struct CapturedScalars{F, Args, Order}
args::Args
end
# for broadcast! see (11)

# capturescalars takes a function (f) and a tuple of mixed sparse vectors/matrices and
# broadcast scalar arguments (mixedargs), and returns a function (parevalf, i.e. partially
Expand All @@ -930,6 +1015,13 @@ end
return (parevalf, passedsrcargstup)
end
end
# Work around losing Type{T}s as DataTypes within the tuple that makeargs creates
@inline capturescalars(f, mixedargs::Tuple{Ref{Type{T}}, Vararg{Any}}) where {T} =
capturescalars((args...)->f(T, args...), Base.tail(mixedargs))
@inline capturescalars(f, mixedargs::Tuple{SparseVecOrMat, Ref{Type{T}}, Vararg{Any}}) where {T} =
capturescalars((a1, args...)->f(a1, T, args...), (mixedargs[1], Base.tail(Base.tail(mixedargs))...))
@inline capturescalars(f, mixedargs::Tuple{Union{Ref,AbstractArray{0}}, Ref{Type{T}}, Vararg{Any}}) where {T} =
capturescalars((args...)->f(mixedargs[1], T, args...), Base.tail(Base.tail(mixedargs)))

nonscalararg(::SparseVecOrMat) = true
nonscalararg(::Any) = false
Expand All @@ -942,11 +1034,17 @@ end
@inline function _capturescalars(arg, mixedargs...)
let (rest, f) = _capturescalars(mixedargs...)
if nonscalararg(arg)
return (arg, rest...), (head, tail...) -> (head, f(tail...)...) # pass-through to broadcast
return (arg, rest...), @inline function(head, tail...)
(head, f(tail...)...)
end # pass-through to broadcast
elseif scalarwrappedarg(arg)
return rest, (tail...) -> (arg[], f(tail...)...) # unwrap and add back scalararg after (in makeargs)
return rest, @inline function(tail...)
(arg[], f(tail...)...) # TODO: This can put a Type{T} in a tuple
end # unwrap and add back scalararg after (in makeargs)
else
return rest, (tail...) -> (arg, f(tail...)...) # add back scalararg after (in makeargs)
return rest, @inline function(tail...)
(arg, f(tail...)...)
end # add back scalararg after (in makeargs)
end
end
end
Expand All @@ -972,69 +1070,18 @@ broadcast(f::Tf, A::SparseMatrixCSC, ::Type{T}) where {Tf,T} = broadcast(x -> f(
# vectors/matrices, promote all structured matrices and dense vectors/matrices to sparse
# and rebroadcast. otherwise, divert to generic AbstractArray broadcast code.

struct PromoteToSparse <: Broadcast.AbstractArrayStyle{2} end
PromoteToSparse(::Val{0}) = PromoteToSparse()
PromoteToSparse(::Val{1}) = PromoteToSparse()
PromoteToSparse(::Val{2}) = PromoteToSparse()
PromoteToSparse(::Val{N}) where N = Broadcast.DefaultArrayStyle{N}()

const StructuredMatrix = Union{Diagonal,Bidiagonal,Tridiagonal,SymTridiagonal}
Broadcast.BroadcastStyle(::Type{<:StructuredMatrix}) = PromoteToSparse()
Broadcast.BroadcastStyle(::Type{<:Adjoint{T,<:Union{SparseVector,SparseMatrixCSC}} where T}) = PromoteToSparse()
Broadcast.BroadcastStyle(::Type{<:Transpose{T,<:Union{SparseVector,SparseMatrixCSC}} where T}) = PromoteToSparse()

Broadcast.BroadcastStyle(s::SPVM, ::Broadcast.AbstractArrayStyle{0}) = s
Broadcast.BroadcastStyle(::SPVM, ::Broadcast.DefaultArrayStyle{1}) = PromoteToSparse()
Broadcast.BroadcastStyle(::SPVM, ::Broadcast.DefaultArrayStyle{2}) = PromoteToSparse()

Broadcast.BroadcastStyle(::PromoteToSparse, ::SPVM) = PromoteToSparse()
Broadcast.BroadcastStyle(::PromoteToSparse, ::Broadcast.Style{Tuple}) = Broadcast.DefaultArrayStyle{2}()

# FIXME: currently sparse broadcasts are only well-tested on known array types, while any AbstractArray
# could report itself as a DefaultArrayStyle().
# See https://github.com/JuliaLang/julia/pull/23939#pullrequestreview-72075382 for more details
is_supported_sparse_broadcast() = true
is_supported_sparse_broadcast(::AbstractArray, rest...) = false
is_supported_sparse_broadcast(::AbstractSparseArray, rest...) = is_supported_sparse_broadcast(rest...)
is_supported_sparse_broadcast(::StructuredMatrix, rest...) = is_supported_sparse_broadcast(rest...)
is_supported_sparse_broadcast(::Array, rest...) = is_supported_sparse_broadcast(rest...)
is_supported_sparse_broadcast(t::Union{Transpose, Adjoint}, rest...) = is_supported_sparse_broadcast(t.parent, rest...)
is_supported_sparse_broadcast(x, rest...) = axes(x) === () && is_supported_sparse_broadcast(rest...)
is_supported_sparse_broadcast(x::Ref, rest...) = is_supported_sparse_broadcast(rest...)
function broadcast(f, s::PromoteToSparse, ::Nothing, ::Nothing, As::Vararg{Any,N}) where {N}
if is_supported_sparse_broadcast(As...)
return broadcast(f, map(_sparsifystructured, As)...)
function copy(bc::Broadcasted{PromoteToSparse})
bcf = flatten(bc)
if is_supported_sparse_broadcast(bcf.args...)
broadcast(bcf.f, map(_sparsifystructured, bcf.args)...)
else
return broadcast(f, Broadcast.ArrayConflict(), nothing, nothing, As...)
return copy(convert(Broadcasted{Broadcast.DefaultArrayStyle{2}}, bc))
end
end

# For broadcast! with ::Any inputs, we need a layer of indirection to determine whether
# the inputs can be promoted to SparseVecOrMat. If it's just SparseVecOrMat and scalars,
# we can handle it here, otherwise see below for the promotion machinery.
function broadcast!(f::Tf, dest::SparseVecOrMat, ::SPVM, A::SparseVecOrMat, Bs::Vararg{SparseVecOrMat,N}) where {Tf,N}
if f isa typeof(identity) && N == 0 && Base.axes(dest) == Base.axes(A)
return copyto!(dest, A)
end
A′ = Base.unalias(dest, A)
Bs′ = map(B->Base.unalias(dest, B), Bs)
_aresameshape(dest, A′, Bs′...) && return _noshapecheck_map!(f, dest, A′, Bs′...)
Base.Broadcast.check_broadcast_indices(axes(dest), A′, Bs′...)
fofzeros = f(_zeros_eltypes(A′, Bs′...)...)
fpreszeros = _iszero(fofzeros)
fpreszeros ? _broadcast_zeropres!(f, dest, A′, Bs′...) :
_broadcast_notzeropres!(f, fofzeros, dest, A′, Bs′...)
return dest
end
function broadcast!(f::Tf, dest::SparseVecOrMat, ::SPVM, mixedsrcargs::Vararg{Any,N}) where {Tf,N}
# mixedsrcargs contains nothing but SparseVecOrMat and scalars
parevalf, passedsrcargstup = capturescalars(f, mixedsrcargs)
broadcast!(parevalf, dest, passedsrcargstup...)
return dest
end
function broadcast!(f::Tf, dest::SparseVecOrMat, ::PromoteToSparse, mixedsrcargs::Vararg{Any,N}) where {Tf,N}
broadcast!(f, dest, map(_sparsifystructured, mixedsrcargs)...)
return dest
@inline function copyto!(dest::SparseVecOrMat, bc::Broadcasted{PromoteToSparse})
bcf = flatten(bc)
broadcast!(bcf.f, dest, map(_sparsifystructured, bcf.args)...)
end

_sparsifystructured(M::AbstractMatrix) = SparseMatrixCSC(M)
Expand All @@ -1047,8 +1094,7 @@ _sparsifystructured(x) = x


# (12) map[!] over combinations of sparse and structured matrices
SparseOrStructuredMatrix = Union{SparseMatrixCSC,StructuredMatrix}
map(f::Tf, A::StructuredMatrix) where {Tf} = _noshapecheck_map(f, _sparsifystructured(A))
SparseOrStructuredMatrix = Union{SparseMatrixCSC,LinearAlgebra.StructuredMatrix}
map(f::Tf, A::SparseOrStructuredMatrix, Bs::Vararg{SparseOrStructuredMatrix,N}) where {Tf,N} =
(_checksameshape(A, Bs...); _noshapecheck_map(f, _sparsifystructured(A), map(_sparsifystructured, Bs)...))
map!(f::Tf, C::SparseMatrixCSC, A::SparseOrStructuredMatrix, Bs::Vararg{SparseOrStructuredMatrix,N}) where {Tf,N} =
Expand Down
Loading

0 comments on commit 0c4ec6e

Please sign in to comment.