-
Notifications
You must be signed in to change notification settings - Fork 15
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add ChainRules support #71
Conversation
Pull Request Test Coverage Report for Build 8870522033Details
💛 - Coveralls |
This comment was marked as resolved.
This comment was marked as resolved.
This looks great, I think I will be able to remove some of the custom handling I had in Lux for this |
Fantastic. Thanks for looking! |
Do you plan to capture ForwardDiff calls as well? I was unsure how to capture them at the Node constants level, for Lux I handled them at the parameters level https://github.com/LuxDL/Lux.jl/blob/main/ext/LuxDynamicExpressionsForwardDiffExt.jl#L8-L52 |
I would be very happy to have ForwardDiff support for tree constants. For my own use-cases its lower on the priority list, so not sure when I'll get to it. The
Nice! I'm not sure how to translate this but let me know if you'd be open to moving it over here. Not sure how much work it would be though.
In the DynamicExpressions.jl/ext/DynamicExpressionsOptimExt.jl Lines 91 to 94 in 27b6199
Then I can update all the parameters by DynamicExpressions.jl/ext/DynamicExpressionsOptimExt.jl Lines 114 to 117 in 27b6199
The nice part about this is that it also works for |
Yes, I am definitely open to moving them here. What I meant with capturing them is how to define the dispatch. For eg, in Lux since I keep the parameters extracted in a vector so it is simple enough to write It is possible to do it here DynamicExpressions.jl/ext/DynamicExpressionsOptimExt.jl Lines 38 to 45 in 27b6199
It is also possible that ForwardDiff might be efficient enough without this special handling, given that you mention FiniteDifferences is already fast. |
911a7e7
to
e0344a1
Compare
e0344a1
to
b688654
Compare
I see, thanks. Seems a bit trickier. Will think more... |
Implements
ChainRulesCore.rrule
foreval_tree_array
for thetree
and theX
argument.For the
tree
argument I had to implement something custom becauseChainRulesCore.Tangent
doesn't support recursive types. To get around this I implementwhere
gradient
is a vector gradient of the constants in the tree in the usual depth-first order. It has some of theAbstractTangent
interface implemented (as much as makes sense).However this probably requires some care in downstream uses because it's not an array.
@avik-pal perhaps this is useful for the Lux.jl extension? (Would love to hear what you think of this PR, btw, given your experience in this area)
TODO:
Optim.optimize
extension to use Zygote AD with this interface. Or at least be compatible with user-passed gradients that return aNodeTangent
.