Gradient of SpecialFunctions.bessel... errors with SpecialFunctions >= 1 - support thunks? #873

Closed
devmotion opened this issue Jan 5, 2021 · 4 comments · Fixed by JuliaMath/SpecialFunctions.jl#308

@devmotion
Collaborator

Recently, the ChainRules definitions for SpecialFunctions were moved from ChainRules to SpecialFunctions. As part of this update, the rules for bessel... were also changed: previously the derivative with respect to the order was defined as NaN, whereas now it is @thunk(error("not implemented")). This seems correct, since the derivative exists but is just not implemented (see JuliaDiff/ChainRules.jl#292 (comment) for some discussion).

However, since Zygote unthunks everything when wrapping existing rrules, even the gradient of SpecialFunctions.bessel... with respect to the second argument can no longer be computed with SpecialFunctions >= 1:

julia> Zygote.gradient(1.0) do x
           besseli(1, x)
       end
ERROR: not implemented
Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] (::SpecialFunctions.var"#27#30")() at /home/david/.julia/packages/SpecialFunctions/ERZOU/src/chainrules.jl:41
 [3] (::ChainRulesCore.Thunk{SpecialFunctions.var"#27#30"})() at /home/david/.julia/packages/ChainRulesCore/cpHLu/src/differentials/thunks.jl:98
 [4] unthunk at /home/david/.julia/packages/ChainRulesCore/cpHLu/src/differentials/thunks.jl:99 [inlined]
 [5] (::ChainRulesCore.var"#11#12"{ChainRulesCore.Thunk{SpecialFunctions.var"#27#30"}})() at /home/david/.julia/packages/ChainRulesCore/cpHLu/src/differentials/thunks.jl:40
 [6] (::ChainRulesCore.Thunk{ChainRulesCore.var"#11#12"{ChainRulesCore.Thunk{SpecialFunctions.var"#27#30"}}})() at /home/david/.julia/packages/ChainRulesCore/cpHLu/src/differentials/thunks.jl:98
 [7] unthunk at /home/david/.julia/packages/ChainRulesCore/cpHLu/src/differentials/thunks.jl:99 [inlined]
 [8] *(::ChainRulesCore.Thunk{ChainRulesCore.var"#11#12"{ChainRulesCore.Thunk{SpecialFunctions.var"#27#30"}}}, ::Float64) at /home/david/.julia/packages/ChainRulesCore/cpHLu/src/differential_arithmetic.jl:111
 [9] besseli_pullback at /home/david/.julia/packages/ChainRulesCore/cpHLu/src/rule_definition_tools.jl:188 [inlined]
 [10] ZBack at /home/david/.julia/packages/Zygote/ywhiG/src/compiler/chainrules.jl:77 [inlined]
 [11] #5 at ./REPL[5]:2 [inlined]
 [12] (::typeof((#5)))(::Float64) at /home/david/.julia/packages/Zygote/ywhiG/src/compiler/interface2.jl:0
 [13] (::Zygote.var"#41#42"{typeof((#5))})(::Float64) at /home/david/.julia/packages/Zygote/ywhiG/src/compiler/interface.jl:40
 [14] gradient(::Function, ::Float64) at /home/david/.julia/packages/Zygote/ywhiG/src/compiler/interface.jl:49
 [15] top-level scope at REPL[5]:1

whereas before (SpecialFunctions 0.10):

julia> Zygote.gradient(1.0) do x
           besseli(1, x)
       end
(0.7009067737595233,)

It seems the correct way to deal with this would be to add support for thunks to Zygote.

@sethaxen
Contributor

sethaxen commented Jan 5, 2021

cc @oxinabox

@oxinabox
Member

oxinabox commented Jan 5, 2021

Urg, yes.
Will be fixed by #603, which is feature-wise almost done but still needs a lot of rebasing and cleanup.
That should let us move the unthunking to the point where the gradient is consumed, rather than the point where it is generated.
Thus, if it is never consumed, we wouldn't unthunk it and wouldn't get the error.
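
To illustrate the difference between unthunking where the gradient is generated and unthunking where it is consumed, here is a minimal, self-contained sketch. It does not use Zygote or ChainRulesCore: `LazyPartial`, `force`, `eager_unthunk`, and `deferred_unthunk` are hypothetical names made up for this illustration, and the numeric value is simply the gradient reported in this issue for SpecialFunctions 0.10.

```julia
# Hypothetical stand-in for a thunk: wraps a zero-argument closure
# that is only run when `force` is called.
struct LazyPartial{F}
    f::F
end
force(p::LazyPartial) = p.f()

# Partials of a two-argument function such as besseli(ν, x):
# the one for the order ν errors when evaluated, the one for x does not.
partials = (
    LazyPartial(() -> error("not implemented")),
    LazyPartial(() -> 0.7009067737595233),  # value reported in this issue
)

# Eager strategy: force every partial as soon as the pullback returns.
eager_unthunk(ps) = map(force, ps)

# Deferred strategy: force only the partial that is actually consumed.
deferred_unthunk(ps, i) = force(ps[i])

# Only the partial for x is requested, so the erroring one never runs.
dx = deferred_unthunk(partials, 2)

# The eager strategy hits the error even though only dx is wanted.
eager_failed = try
    eager_unthunk(partials)
    false
catch
    true
end
```

With the deferred strategy the erroring partial for the order is never evaluated, which is the behaviour described above for a gradient that is never consumed.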

Without that it is a bit hard (though not impossible) to support thunks.
I think we could still do the above, but we would need to not only unthunk but also handle converting any other ChainRules types into the Zygote types (Zero -> Nothing, Composite -> Tuple/NamedTuple). We do have a function that does both that and the unthunking, so maybe we can just move all that logic to where the gradient is consumed.

But also SpecialFunctions should implement that gradient

@devmotion
Collaborator Author

But also SpecialFunctions should implement that gradient

It should, but the analytical expressions in e.g. JuliaDiff/ChainRules.jl#208 would require adding hypergeometric functions to SpecialFunctions, e.g. by depending on HypergeometricFunctions or by adding new implementations to SpecialFunctions.

@devmotion
Collaborator Author

This specific issue with SpecialFunctions is fixed now by using ChainRulesCore.@not_implemented in SpecialFunctions (JuliaMath/SpecialFunctions.jl#308).
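
For reference, a minimal sketch of what such a placeholder looks like, assuming a ChainRulesCore version that provides `@not_implemented`. The variable name and the message text are made up for this example and are not taken from the actual SpecialFunctions rule.

```julia
using ChainRulesCore

# Placeholder tangent for a derivative that exists but has no implementation.
# Constructing it does not throw, unlike evaluating @thunk(error(...)).
∂order = ChainRulesCore.@not_implemented("derivative with respect to the order is not implemented")

# The placeholder is an ordinary tangent object that can be passed around.
∂order isa ChainRulesCore.NotImplemented
```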
