Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RFC: rules related to 3-arg * #412

Merged
merged 10 commits into from
Jan 17, 2022
Merged

RFC: rules related to 3-arg * #412

merged 10 commits into from
Jan 17, 2022

Conversation

mcabbott
Copy link
Member

JuliaLang/julia#37898 adds some 3- and 4-arg methods to *, do to things like contract matrices in the optimal order. That eventually calls pairwise * so shouldn't break AD.

It also fuses scalar-matrix-matrix multiplication, which involves calling mul!, and thus I think will break things. So this adds some rules!

The RFC is about exactly which functions to attach rules to, and with what signatures. The basic function is mat_vec_scalar(A, x, γ) = A * (x .* γ), but this has some methods which call _mat_vec_scalar which actually does the mutation. The minimal thing would be to attach the rule to _mat_vec_scalar, but it might be more efficient to call the fused rule in some other cases, rather than leaving AD to sort out the broadcast etc. (It's also not too late to fiddle with the Base PR, if someone has bright ideas.)

Tests will obviously fail on CI; they also fail locally for complex numbers right now.

Copy link
Member

@mzgubic mzgubic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding these. Have you got any examples of where having mat_mat_scalar rrule is better than having _mat_mat_scalar?

I think the signature also looks reasonable, but perhaps @willtebbutt might want to chip in?

src/rulesets/Base/arraymath.jl Outdated Show resolved Hide resolved
src/rulesets/Base/arraymath.jl Outdated Show resolved Hide resolved
src/rulesets/Base/arraymath.jl Outdated Show resolved Hide resolved
@nickrobinson251
Copy link
Contributor

Would this be the first time ChainRules.jl has rules for non-public functions (e.g. mat_mat_scalar or _mat_mat_scalar)?
How do we feel about rules for non-public functions?

@codecov-commenter
Copy link

codecov-commenter commented Jun 15, 2021

Codecov Report

Merging #412 (1ad05aa) into master (36508af) will decrease coverage by 1.42%.
The diff coverage is 0.00%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master     #412      +/-   ##
==========================================
- Coverage   98.04%   96.62%   -1.43%     
==========================================
  Files          22       22              
  Lines        2306     2340      +34     
==========================================
  Hits         2261     2261              
- Misses         45       79      +34     
Impacted Files Coverage Δ
src/rulesets/Base/arraymath.jl 81.90% <0.00%> (-16.88%) ⬇️

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 36508af...1ad05aa. Read the comment docs.

@ChrisRackauckas
Copy link
Member

Would this be the first time ChainRules.jl has rules for non-public functions (e.g. mat_mat_scalar or _mat_mat_scalar)? How do we feel about rules for non-public functions?

Well the parser forces these to be used by users in normal Julia, so it's pretty essential.

@mcabbott
Copy link
Member Author

mcabbott commented Dec 22, 2021

What I think Nick means is that the rules could instead be attached to *. The signatures you'd have to make sure to catch are these:

https://github.com/JuliaLang/julia/pull/37898/files#diff-a454fe10464022052956170246cb4be7e2e71bc517e02c2ec0a9548fb6aaaf29R1115-R1120

Edit -- One reason not to, BTW, is that it might make sense to send more cases to these functions. I forget where this came up but some combinations with I showed up in some rule I think, and:

julia> A = rand(100,100); x = rand(100); J = rand()*I;

julia> y = @btime $A * $I * $x;
  min 3.613 μs, mean 12.358 μs (3 allocations, 79.05 KiB)

julia> y ≈ @btime LinearAlgebra.mat_vec_scalar($A, $x, $I.λ)
  min 1.921 μs, mean 2.061 μs (1 allocation, 896 bytes)
true

Copy link
Member

@mzgubic mzgubic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mostly cosmetic comments, LGTM generally. Thanks!

Just to make sure I understand the reason for not extending *: we might add more * methods that use mat_mat_scalar, which would require adding more * rules here. With the current implementation, adding rules would not be necessary.

Oh, one more thing: is the ReversePropagation integration test failure real?

src/rulesets/Base/arraymath.jl Outdated Show resolved Hide resolved
test/rulesets/Base/arraymath.jl Outdated Show resolved Hide resolved
test/rulesets/Base/arraymath.jl Outdated Show resolved Hide resolved
test_rrule(mat_mat_scalar, rand(T,4,4)' ⊢ rand(T,4,4), rand(T,4,4), rand(T))

test_rrule(mat_vec_scalar, rand(T,4,4), rand(T,4), rand(T))
test_rrule(mat_vec_scalar, rand(T,4,4), rand(T,4), 0.0)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we add a test for the adjoint/transpose here as well?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I forgot this one, sorry. Other changes adopted.

test/rulesets/Base/arraymath.jl Outdated Show resolved Hide resolved
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

6 participants