-
Notifications
You must be signed in to change notification settings - Fork 98
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add missing overload for ZeroTangent
#256
Conversation
Codecov ReportAttention:
... and 9 files with indirect coverage changes 📢 Thoughts on this report? Let us know!. |
It seems a bit restrictive, indeed. But did you ran into an actual problem or is it just a potential issue? In the latter case I would wait with changes until one encounters this case in an application. I also thought that the ChainRulesTestUtils passed, so in case this is a bug it would be good to generally check this |
I linked an issue in KernelFunctions.jl that hits this, if the ZygoteDistancesExt is deactivated. It could be that the AD tests from KernelFunctions.jl are doing something nonstandard, though. |
Yes, the ChainRulesCore docs say so: https://juliadiff.org/ChainRulesCore.jl/stable/rule_author/writing_good_rules.html#Ensure-your-pullback-can-accept-the-right-types . If Zygote is erroring, perhaps best to open an issue there or see if one is already opened. |
Based on the ChainRules docs, my impression is that there's no problem with the rules in Distances but possibly a Zygote issue. |
Thanks for investigating this! |
Zygote currently should be
My interpretation of the CRC docs is that this may be a grey area. Yes, rules should not have to handle zeros passed in. But there is no reliable way to tell a priori whether a given thunk will evaluate to a zero type, so that needs to be handled somehow. Rules may also want to handle the case where they get a mix of zeros and non-zeros for a composite return type, e.g. |
I suspect the problem here (and the |
Yes, the issue is absent when adding ChainRules@1.52.1, as mentioned in FluxML/Zygote.jl#1460 (comment) |
Basically, with ChainRules > v1.53.0: julia> f(x) = iszero(x) ? zero(x) : x
f (generic function with 1 method)
julia> using Zygote
julia> Zygote.gradient(f, 0.0)
(nothing,) whereas with ChainRules v1.52.1: julia> f(x) = iszero(x) ? zero(x) : x
f (generic function with 1 method)
julia> using Zygote
julia> Zygote.gradient(f, 0.0)
(0.0,) To my non-expert eyes it looks as if |
Zygote has been converting |
The failing examples all involve a broadcasting path with |
I'm not sure if we're referring to the same set of examples, but I've been focusing on the third one in FluxML/Zygote.jl#1460 (comment). That one shouldn't hit the Zygote broadcasting path which uses ForwardDiff, because it hits the CR rule at https://github.com/JuliaDiff/ChainRules.jl/blob/v1.55.0/src/rulesets/Base/mapreduce.jl#L76.
Zygote's own AD rules having a lot of edge cases and not handling data types off the beaten path nicely is well-known I think. That's why they're being slowly removed in favour of better-written rrules in ChainRules or elsewhere. |
This seems to be missing as the unthunked tangent can sometimes be a
ZeroTangent
. Does this make sense, @devmotion?