RFC: rules related to 3-arg *
#412
Thanks for adding these. Have you got any examples of where having the `mat_mat_scalar` rrule is better than having the `_mat_mat_scalar` rrule?
I think the signature also looks reasonable, but perhaps @willtebbutt might want to chip in?
Would this be the first time ChainRules.jl has rules for non-public functions (e.g. |
Codecov Report
@@ Coverage Diff @@
## master #412 +/- ##
==========================================
- Coverage 98.04% 96.62% -1.43%
==========================================
Files 22 22
Lines 2306 2340 +34
==========================================
Hits 2261 2261
- Misses 45 79 +34
Well, the parser forces these to be used by users in normal Julia, so it's pretty essential.
What I think Nick means is that the rules could instead be attached to Edit -- One reason not to, BTW, is that it might make sense to send more cases to these functions. I forget where this came up but some combinations with
Mostly cosmetic comments, LGTM generally. Thanks!
Just to make sure I understand the reason for not extending `*`: we might add more `*` methods that use `mat_mat_scalar`, which would require adding more `*` rules here. With the current implementation, adding rules would not be necessary.
Oh, one more thing: is the ReversePropagation integration test failure real?
test_rrule(mat_mat_scalar, rand(T,4,4)' ⊢ rand(T,4,4), rand(T,4,4), rand(T))

test_rrule(mat_vec_scalar, rand(T,4,4), rand(T,4), rand(T))
test_rrule(mat_vec_scalar, rand(T,4,4), rand(T,4), 0.0)
should we add a test for the adjoint/transpose here as well?
Oh I forgot this one, sorry. Other changes adopted.
JuliaLang/julia#37898 adds some 3- and 4-arg methods to `*`, to do things like contract matrices in the optimal order. That eventually calls pairwise `*`, so it shouldn't break AD.

It also fuses scalar-matrix-matrix multiplication, which involves calling `mul!`, and thus I think will break things. So this adds some rules!

The RFC is about exactly which functions to attach rules to, and with what signatures. The basic function is `mat_vec_scalar(A, x, γ) = A * (x .* γ)`, but this has some methods which call `_mat_vec_scalar`, which actually does the mutation. The minimal thing would be to attach the rule to `_mat_vec_scalar`, but it might be more efficient to call the fused rule in some other cases, rather than leaving AD to sort out the broadcast etc. (It's also not too late to fiddle with the Base PR, if someone has bright ideas.)

Tests will obviously fail on CI; they also fail locally for complex numbers right now.
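For readers who want to see what a fused reverse rule for `mat_vec_scalar` computes, here is a minimal sketch restricted to real inputs. It is not the code in this PR: the helper name `mat_vec_scalar_with_pullback` and the names `dy`, `dA`, `dx`, `dγ` are hypothetical, and it returns a plain tuple rather than following the ChainRulesCore `rrule` interface. Complex inputs are deliberately left out.

```julia
using LinearAlgebra

# Reference definition quoted in the PR description.
mat_vec_scalar(A, x, γ) = A * (x .* γ)

# Hypothetical helper (an assumption for illustration, not part of Base or
# ChainRules.jl): returns the primal y = γ * A * x and a pullback closure.
function mat_vec_scalar_with_pullback(A::AbstractMatrix{<:Real},
                                      x::AbstractVector{<:Real},
                                      γ::Real)
    Ax = A * x          # reused by the pullback for the scalar cotangent
    y = Ax .* γ
    function pullback(dy)
        dA = (dy .* γ) * x'     # outer product: cotangent of A
        dx = A' * (dy .* γ)     # cotangent of x
        dγ = dot(Ax, dy)        # cotangent of the scalar γ
        return dA, dx, dγ
    end
    return y, pullback
end

A = [1.0 2.0; 3.0 4.0]
x = [1.0, -1.0]
γ = 2.0
y, pb = mat_vec_scalar_with_pullback(A, x, γ)
dA, dx, dγ = pb([1.0, 1.0])
```

The point of writing the three cotangents by hand is that none of them needs AD to differentiate through the broadcast `x .* γ` or through a mutating `mul!` call.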