add ForwardDiff support as weak dependency? #349

Closed
longemen3000 opened this issue Mar 1, 2023 · 5 comments


longemen3000 (Contributor) commented Mar 1, 2023

Now that there is support for weak dependencies (package extensions, introduced in Julia 1.9), adding ForwardDiff support seems viable. This could directly use the ChainRulesCore.frule definition already present in the package.
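For reference, the tangent that a forward rule for solve has to produce follows from the implicit function theorem. This is a standard derivation, not taken from the package's own rule. If x*(p) solves f(x, p) = 0 and ∂f/∂x ≠ 0 at the root, differentiating that identity gives:

```latex
f\bigl(x^*(p),\, p\bigr) = 0
\;\Longrightarrow\;
\frac{\partial f}{\partial x}\,\dot{x}^* + \frac{\partial f}{\partial p}\,\dot{p} = 0
\;\Longrightarrow\;
\dot{x}^* = -\,\frac{\partial f / \partial p}{\partial f / \partial x}\,\dot{p}
```

Both partials are evaluated at (x*(p), p), so only the converged root is needed, not the solver's iterations.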

jverzani (Member) commented Mar 2, 2023

Yes, thanks for pointing this out!

jverzani (Member) commented Mar 3, 2023

I thought this was as straightforward as setting up this ForwardDiffExt file:

module ForwardDiffExt

using Roots
using ForwardDiff
using ForwardDiffChainRules
import CommonSolve: solve

@ForwardDiff_frule solve(ZP::ZeroProblem, M::Roots.AbstractUnivariateZeroMethod,
                         p::ForwardDiff.Dual)
@ForwardDiff_frule solve(ZP::ZeroProblem, M::Roots.AbstractUnivariateZeroMethod,
                         p::AbstractVector{<: ForwardDiff.Dual})

end

But I get an error when trying it out:

julia> f(x,p) = cos(x) - x*p; F(p) = solve(ZeroProblem(f, (0, pi/2)), Bisection(), p)
F (generic function with 1 method)

julia> ForwardDiff.derivative(F, 1.0)
ERROR: MethodError: no method matching iterate(::Nothing)
Closest candidates are:
  iterate(::Union{LinRange, StepRangeLen}) at range.jl:872
  iterate(::Union{LinRange, StepRangeLen}, ::Integer) at range.jl:872
  iterate(::T) where T<:Union{Base.KeySet{<:Any, <:Dict}, Base.ValueIterator{<:Dict}} at dict.jl:712
  ...
Stacktrace:
 [1] indexed_iterate(I::Nothing, i::Int64)
   @ Base ./tuple.jl:91
 [2] solve(ZP::ZeroProblem{typeof(f), Tuple{Int64, Float64}}, M::Bisection, p::ForwardDiff.Dual{ForwardDiff.Tag{typeof(F), Float64}, Float64, 1})
   @ Main ~/.julia/packages/ForwardDiffChainRules/s5si2/src/ForwardDiffChainRules.jl:62
 [3] F(p::ForwardDiff.Dual{ForwardDiff.Tag{typeof(F), Float64}, Float64, 1})
   @ Main ./REPL[6]:1
 [4] derivative(f::typeof(F), x::Float64)
   @ ForwardDiff ~/.julia/packages/ForwardDiff/QdStj/src/derivative.jl:14
 [5] top-level scope
   @ REPL[7]:1

Do you have any insight into how to leverage the ChainRulesCore.frule already defined?
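The iterate(::Nothing) error is what Julia raises when nothing is destructured into two variables. ChainRulesCore's generic fallback frule returns nothing when no rule matches the call, so one plausible reading of the stack trace is that the macro-generated solve method called frule with a signature the package's rule does not cover and then tried to unpack the result. A minimal Base-only reproduction, with a hypothetical lookup_rule standing in for that fallback:

```julia
# Hypothetical stand-in for a rule lookup that finds no matching rule;
# ChainRulesCore's generic fallback `frule` likewise returns `nothing`.
lookup_rule(args...) = nothing

function apply_rule(args...)
    # Unpacking the result as `(y, dy)` calls `iterate` on it,
    # and there is no `iterate` method for `nothing`.
    y, dy = lookup_rule(args...)
    return y, dy
end

err = try
    apply_rule(1.0)
    nothing
catch e
    e
end

# Print the same kind of message as in the REPL session above.
showerror(stdout, err)
println()
```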

devmotion (Member)

Some quick comments:

  • It would be good to give the extension a more distinctive name, since extensions with the same name currently break system image generation. A common approach is to prefix it with the package name, i.e., RootsForwardDiffExt.
  • An extension must not load any package other than the parent package and its weak dependency/dependencies. So you would have to use, e.g., import Roots: solve instead of import CommonSolve: solve (assuming you loaded CommonSolve.solve inside of Roots).
  • If the extension used ForwardDiffChainRules, it seems it would have to be a weak dependency as well. In that case the rules would only be defined if users also load ForwardDiffChainRules (either implicitly or explicitly). I think you would rather want the rules to be available as soon as ForwardDiff is loaded, so you probably don't want to use ForwardDiffChainRules but instead manually leverage the existing ChainRules definitions if possible.

jverzani (Member) commented Mar 3, 2023

Thanks! That is super helpful.

jverzani (Member) commented Mar 3, 2023

In #352 I have an implementation that seems to work. I couldn't figure out how to leverage the frule definition, but this approach seems to give the correct answers.
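The thread does not spell out what #352 does internally. One common way to get correct derivatives without differentiating through the solver's iterations is the implicit function theorem: once the root x* of f(x, p) = 0 is found, dx*/dp = -(∂f/∂p)/(∂f/∂x) evaluated at (x*, p). The following Base-only sketch applies this to the example f(x, p) = cos(x) - x*p from this thread. It uses a hand-written bisection (a hypothetical stand-in for Roots' Bisection()) and hand-coded partial derivatives, and cross-checks the result against a central finite difference of the solver itself.

```julia
# Example from this thread and its hand-coded partial derivatives
# (assumption: in practice these partials could come from AD).
f(x, p) = cos(x) - x * p
dfdx(x, p) = -sin(x) - p
dfdp(x, p) = -x

# Plain bisection on a bracket [a, b] where g changes sign
# (hypothetical helper, not the Roots.jl implementation).
function bisect(g, a, b; atol = 1e-14, maxiter = 200)
    ga = g(a)
    for _ in 1:maxiter
        m = (a + b) / 2
        gm = g(m)
        if sign(gm) == sign(ga)
            a, ga = m, gm   # root lies in [m, b]
        else
            b = m           # root lies in [a, m]
        end
        b - a < atol && break
    end
    return (a + b) / 2
end

# Root x*(p) on the bracket (0, pi/2), as in the thread's example.
root(p) = bisect(x -> f(x, p), 0.0, pi / 2)

# Implicit function theorem: dx*/dp = -(df/dp) / (df/dx) at the root.
function droot(p)
    x = root(p)
    return -dfdp(x, p) / dfdx(x, p)
end

# Cross-check against a central finite difference of the solver.
p = 1.0
h = 1e-6
fd = (root(p + h) - root(p - h)) / (2h)
println((implicit = droot(p), finite_difference = fd))
```

In an actual extension, the partials would presumably be computed with ForwardDiff and the result repackaged as a ForwardDiff.Dual; that wiring is left out of this sketch.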

jverzani added a commit that referenced this issue Mar 7, 2023
* WIP

* WIP

* Add extension for ForwardDiff

* Update ext/RootsForwardDiffExt.jl

Co-authored-by: David Widmann <devmotion@users.noreply.github.com>

* Update ext/RootsForwardDiffExt.jl

Co-authored-by: David Widmann <devmotion@users.noreply.github.com>

* Update ext/RootsForwardDiffExt.jl

Co-authored-by: David Widmann <devmotion@users.noreply.github.com>

* Update ext/RootsForwardDiffExt.jl

Co-authored-by: David Widmann <devmotion@users.noreply.github.com>

* cleanup

* add keyword parameter test

* only weakdep; thx!

* cleanup

---------

Co-authored-by: David Widmann <devmotion@users.noreply.github.com>

3 participants