
Add ChainRules support #71

Merged · 6 commits merged into master from chainrules-core on Apr 28, 2024
Conversation

@MilesCranmer (Member) commented Apr 28, 2024

Implements ChainRulesCore.rrule for eval_tree_array, for both the tree argument and the X argument.

For the tree argument I had to implement something custom, because ChainRulesCore.Tangent doesn't support recursive types. To get around this, I implement:

```julia
struct NodeTangent{T,N<:AbstractExpressionNode{T},A<:AbstractArray{T}} <: AbstractTangent
    tree::N
    gradient::A
end
```

where gradient is the vector of gradients with respect to the constants in the tree, in the usual depth-first order. It has as much of the AbstractTangent interface implemented as makes sense.

However, this probably requires some care in downstream uses, because it's not an array.
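For orientation, here is a rough sketch of how such a rule can be assembled from the existing forward-mode utility eval_grad_tree_array. This is illustrative only, not the exact code merged in this PR: the cotangent handling, the gradient shape conventions, and NodeTangent's import location are assumptions.

```julia
using ChainRulesCore: ChainRulesCore, NoTangent
using DynamicExpressions: eval_tree_array, eval_grad_tree_array
using DynamicExpressions: NodeTangent  # assumed import location for the type above

function ChainRulesCore.rrule(::typeof(eval_tree_array), tree, X, operators)
    primal = eval_tree_array(tree, X, operators)  # (output_vector, completed_flag)
    function eval_tree_array_pullback(dprimal)
        dy = first(dprimal)  # cotangent of the output vector; the Bool flag gets none
        # Forward-mode Jacobians w.r.t. the tree's constants and w.r.t. X:
        _, dconsts, _ = eval_grad_tree_array(tree, X, operators; variable=false)
        _, dXrows, _ = eval_grad_tree_array(tree, X, operators; variable=true)
        dconstants = dconsts * dy  # (n_constants × n_rows) * (n_rows,) -> n_constants
        dX = dXrows .* dy'         # scale each data column by its cotangent
        return (NoTangent(), NodeTangent(tree, dconstants), dX, NoTangent())
    end
    return primal, eval_tree_array_pullback
end
```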

@avik-pal perhaps this is useful for the Lux.jl extension? (Would love to hear what you think of this PR, btw, given your experience in this area)


TODO:

  • (Maybe for later) Rewrite the Optim.optimize extension to use Zygote AD with this interface, or at least make it compatible with user-passed gradients that return a NodeTangent (see the sketch below).
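As a rough sketch of what that Zygote-based path could look like (illustrative only: tree, X, y, and operators are assumed to be in scope, and whether Zygote hands back the NodeTangent directly is an assumption here):

```julia
using Zygote
using DynamicExpressions: eval_tree_array

# Mean-squared-error loss over the expression's output.
loss(t) = sum(abs2, first(eval_tree_array(t, X, operators)) .- y) / length(y)

# With the rrule in place, Zygote should route through it and return a
# NodeTangent whose `gradient` field is the depth-first constants gradient.
(dtree,) = Zygote.gradient(loss, tree)
dtree.gradient  # ∂loss/∂constants, usable as a user-supplied gradient for Optim
```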

@coveralls commented Apr 28, 2024

Pull Request Test Coverage Report for Build 8870522033

Details

  • 27 of 27 (100.0%) changed or added relevant lines in 1 file are covered.
  • No unchanged relevant lines lost coverage.
  • Overall coverage increased (+0.08%) to 94.789%

Totals (change from base Build 8870212630: +0.08%)

  • Covered lines: 1637
  • Relevant lines: 1727

💛 - Coveralls


@avik-pal commented:

> @avik-pal perhaps this is useful for the Lux.jl extension? (Would love to hear what you think of this PR, btw, given your experience in this area)

This looks great; I think I will be able to remove some of the custom handling I had in Lux for this.

@MilesCranmer (Member, Author) commented:

Fantastic. Thanks for looking!

@avik-pal commented:

Do you plan to capture ForwardDiff calls as well? I was unsure how to capture them at the Node constants level; for Lux, I handled them at the parameters level: https://github.com/LuxDL/Lux.jl/blob/main/ext/LuxDynamicExpressionsForwardDiffExt.jl#L8-L52

@MilesCranmer (Member, Author) commented Apr 28, 2024:

> Do you plan to capture ForwardDiff calls as well?

I would be very happy to have ForwardDiff support for tree constants. For my own use-cases it's lower on the priority list, so I'm not sure when I'll get to it. The rrule is the priority for me so far, as I want to have some Zygote-based AD optimization in SymbolicRegression.jl (right now it's still finite-difference-based, which surprisingly hasn't been so bad given it's low-dimensional, but it can get a bit slow for very complex expressions).

> I was unsure how to capture them at the Node constants level; for Lux, I handled them at the parameters level: https://github.com/LuxDL/Lux.jl/blob/main/ext/LuxDynamicExpressionsForwardDiffExt.jl#L8-L52

Nice! I'm not sure how to translate this, but let me know if you'd be open to moving it over here. I'm not sure how much work it would be, though.

> capture them at the Node constants level

In the Optim.optimize extension, what I will do is store a vector of Refs to the constant nodes, and just update them via dereferencing. (Not sure if this is what you were asking.)

```julia
# Collect one Ref per constant node (depth-first), and read out the initial values:
constant_refs = filter_map(
    t -> t.degree == 0 && t.constant, t -> Ref(t), tree, Ref{typeof(tree)}
)
x0 = T[copy(t[].val) for t in constant_refs]
```

Then I can update all the parameters with:

```julia
minimizer = Optim.minimizer(base_res)
@inbounds for i in eachindex(constant_refs, minimizer)
    constant_refs[i][].val = minimizer[i]  # write each optimized value back into its node
end
```

The nice part about this is that it also works for GraphNode, where you have multiple parents pointing to the same child: filter_map will only return a single Ref to the child node, so you don't end up optimizing the same parameter twice (see the sketch below).
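A toy sketch of that deduplication property (constructor and overload details are assumed to match DynamicExpressions' GraphNode docs, so treat this as approximate):

```julia
using DynamicExpressions

operators = OperatorEnum(; binary_operators=[+, *])  # also defines node overloads
x1 = GraphNode{Float64}(; feature=1)
c = GraphNode{Float64}(; val=3.0)
shared = x1 * c         # subtree holding the constant
tree = shared + shared  # `shared` (and its constant) now has two parents

constant_refs = filter_map(
    t -> t.degree == 0 && t.constant, t -> Ref(t), tree, Ref{typeof(tree)}
)
length(constant_refs) == 1  # true: the shared constant is collected only once
```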

@avik-pal commented:

Yes, I am definitely open to moving them here. What I meant by capturing them is how to define the dispatch. For example, in Lux, since I keep the parameters extracted in a vector, it is simple enough to write ::AbstractVector{<:Dual}. I am not sure how to detect ForwardDiff Duals "nicely" when they are part of the Nodes (a toy sketch of the contrast follows below).
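To make the dispatch question concrete, here is a toy sketch (all names hypothetical) of why the parameters level is easy to capture while the node level is not:

```julia
using ForwardDiff: Dual

# Parameters level: the flat vector's element type is visible in the signature,
# so the ForwardDiff call path is one extra method away.
eval_flat(ps::AbstractVector{<:Real}) = sum(abs2, ps)  # plain path
eval_flat(ps::AbstractVector{<:Dual}) = sum(abs2, ps)  # Dual path; custom handling
                                                       # would go here

# Node-constants level: the Duals live inside each node's `val` field, so a call
# sees only e.g. `tree::Node{T}`; detecting them means inspecting the type
# parameter (or walking the tree) rather than ordinary argument dispatch.
is_dual_eltype(::Type{T}) where {T} = T <: Dual
```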

It is possible to do it here, I think, because this code won't natively work with ForwardDiff:

```julia
function wrapped_f(args::Vararg{Any,M}) where {M}
    first_args = args[1:(end - 1)]
    x = last(args)
    @inbounds for i in eachindex(constant_refs, x)
        constant_refs[i][].val = x[i]  # write-back fails when x has Dual eltype
    end
    return @inline(f(first_args..., tree))
end
```
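To make the incompatibility concrete, a toy sketch (ToyNode stands in for a node with a concretely typed val field):

```julia
using ForwardDiff: Dual

mutable struct ToyNode{T<:Real}
    val::T
end

node = ToyNode(1.5)   # val field is concretely Float64
x = [Dual(2.0, 1.0)]  # the Dual-valued vector ForwardDiff would pass to wrapped_f
# node.val = x[1]     # errors: a Dual carrying partials cannot be stored in a
#                     # Float64 field, so the write-back loop cannot run natively
#                     # under ForwardDiff
```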

It is also possible that ForwardDiff might be efficient enough without this special handling, given that you mention FiniteDifferences is already fast.

@MilesCranmer (Member, Author) commented:

I see, thanks. Seems a bit trickier. Will think more...

@MilesCranmer merged commit 6211067 into master on Apr 28, 2024 (16 checks passed).
@MilesCranmer deleted the chainrules-core branch on April 28, 2024 at 22:18.