Optimize + for OneElement #260

Open
oxinabox opened this issue Jun 1, 2023 · 2 comments

@oxinabox
Member

oxinabox commented Jun 1, 2023

There is a significant optimization we can make to + when one of the arguments is a OneElement.

I wrote out the general case for it here:
https://github.com/JuliaDiff/ChainRules.jl/pull/717/files#diff-3ebfe4c6177a89aaa1620d8565d89c551d882f7def080c63c343779c16366741R114-R124
though that uses ChainRulesCore.is_inplaceable_destination to work out whether an array can be mutated, which would need to be stripped before porting, since there have been arguments in the past against depending directly on CRC (ChainRulesCore) in this package.

For this case, on a 300x300 array, it gives a 3x speedup (JuliaDiff/ChainRules.jl#717 (comment)).
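
A minimal sketch of the dense-plus-OneElement case, assuming the `val` and `ind` fields of FillArrays' `OneElement`. The name `add_one_element` is hypothetical, and unlike the linked PR this version always copies instead of checking whether the destination can be mutated:

```julia
using FillArrays

# Hypothetical helper (not part of FillArrays): add a OneElement to a dense
# array by copying once and updating a single entry, rather than visiting
# every element of both arguments.
function add_one_element(xs::AbstractArray, oe::OneElement)
    axes(xs) == axes(oe) || throw(DimensionMismatch("axes of both arguments must match"))
    T = promote_type(eltype(xs), eltype(oe))
    out = copyto!(similar(xs, T), xs)   # always copy; the input is never mutated
    # Guard the write in case the stored index lies outside the array.
    if checkbounds(Bool, out, oe.ind...)
        out[oe.ind...] += oe.val
    end
    return out
end

xs = rand(300, 300)
oe = OneElement(2.0, (3, 4), (300, 300))
ys = add_one_element(xs, oe)
```

Keeping this as a plain function, not a method of `+`, sidesteps the question of which package should own the method.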

There are even more optimal cases when both arguments are OneElement (at the cost of complete type stability), and possibly for other types too.
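
A rough sketch of the two-OneElement case, to make the type-stability trade-off concrete. The helper name `add_one_elements` is hypothetical, and the three-argument `OneElement(val, ind, axes)` constructor taking tuples is an assumption about FillArrays:

```julia
using FillArrays

# Hypothetical helper: the return type depends on the runtime index values,
# which is where type stability is lost.
function add_one_elements(a::OneElement, b::OneElement)
    axes(a) == axes(b) || throw(DimensionMismatch("axes of both arguments must match"))
    if a.ind == b.ind
        # Same position: the sum is itself a OneElement, no dense storage needed.
        return OneElement(a.val + b.val, a.ind, axes(a))
    end
    # Different positions: fall back to a dense array with two nonzero entries.
    out = zeros(promote_type(eltype(a), eltype(b)), size(a))
    checkbounds(Bool, out, a.ind...) && (out[a.ind...] += a.val)
    checkbounds(Bool, out, b.ind...) && (out[b.ind...] += b.val)
    return out
end

a = OneElement(2.0, 1, 3)
b = OneElement(3.0, 1, 3)
c = OneElement(1.0, 2, 3)
same_index = add_one_elements(a, b)   # stays a OneElement
diff_index = add_one_elements(a, c)   # becomes a dense Vector
```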

@jishnub
Member

jishnub commented Jun 1, 2023

This certainly seems like a good optimization to include, but extending Base.:(+)(xs::AbstractArray, oe::OneElement) opens one up to ambiguities:

julia> Base.:(+)(xs::AbstractArray, oe::OneElement) = add(xs, oe)

julia> Base.:(+)(oe::OneElement, xs::AbstractArray) = +(xs, oe)

julia> Base.:(+)(oe1::OneElement, oe2::OneElement) = +(collect(oe1), oe2)

julia> function add(a, b::FillArrays.OneElementVector)
           a[b.ind...] += b.val
           a
       end
add (generic function with 1 method)

julia> using StaticArrays

julia> SA[1,2] + OneElement(2,2)
ERROR: MethodError: +(::SVector{2, Int64}, ::OneElement{Int64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}}) is ambiguous.

Candidates:
  +(xs::AbstractArray, oe::OneElement)
    @ Main REPL[5]:1
  +(a::StaticArray, b::AbstractArray)
    @ StaticArrays ~/.julia/packages/StaticArrays/J9itA/src/linalg.jl:14

Possible fix, define
  +(::StaticArray, ::OneElement)

Stacktrace:
 [1] top-level scope
   @ REPL[12]:1

Perhaps defining a broadcast style might be a better idea here?
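
The ambiguity above can be reproduced without either package. In this sketch every name is a hypothetical stand-in: `ToyStatic` plays the role of `StaticArray`, `ToyOneElement` plays `OneElement`, and `plus` plays `+`:

```julia
# Toy stand-ins only; none of these names exist in FillArrays or StaticArrays.
abstract type ToyArray end
struct ToyStatic <: ToyArray end       # plays the role of StaticArray
struct ToyOneElement <: ToyArray end   # plays the role of OneElement

plus(::ToyArray, ::ToyOneElement) = :one_element_method   # "our" new method
plus(::ToyStatic, ::ToyArray) = :static_method            # the other package's method

# Each method is more specific in one argument and less specific in the other,
# so dispatch cannot choose between them for (ToyStatic, ToyOneElement).
was_ambiguous = try
    plus(ToyStatic(), ToyOneElement())
    false
catch err
    err isa MethodError
end

# The tie-breaker the error message suggests: a method specific in both arguments.
plus(::ToyStatic, ::ToyOneElement) = :tie_breaker
resolved = plus(ToyStatic(), ToyOneElement())
```

The tie-breaker has to name both types, so for real packages it can only live somewhere that loads both of them.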

@oxinabox
Member Author

oxinabox commented Jun 1, 2023

Odds are high that such a broadcasting style would be similarly ambiguous, though more likely not with StaticArrays, since that doesn't define a broadcasting style, but with something that does, like NamedDims.jl or DistributedArrays.jl.
But maybe we could use the SparseArrays broadcasting style, where the ambiguities are probably already resolved and a bunch of neat optimizations are already done.
Though it's not obvious that it is designed to be extended:
https://github.com/JuliaSparse/SparseArrays.jl/blob/7df375c950477fe2122f29e2dde841b2c18e1dca/src/higherorderfns.jl

In general, ambiguities like this are inevitable in Julia.
It's one of the main reasons weakdeps were introduced.
