Definition of pullback for logpdf is overly optimistic #121

Open
willtebbutt opened this issue Oct 16, 2020 · 11 comments

@willtebbutt
Member

willtebbutt commented Oct 16, 2020

This definition is overly optimistic about the inputs it thinks it can handle.

In particular, it hijacks control away from this method in Stheno and causes AD to do something entirely inappropriate, in the sense that if this rule didn't exist, my code would work just fine. It causes problems similar to type piracy; see this well-known ChainRules issue, which explains the core of the problem.

TLDR: defining rules for abstract types causes problems. Since we need to be able to work with abstract types at the minute, this means that you have to be really careful about the abstract types for which you implement rules.
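A toy model may make the TLDR concrete. The Python sketch below is only an analogy, not Zygote or ChainRules code, and every class and function in it is hypothetical: rules are found by walking the argument's class hierarchy from most to least specific, loosely mirroring how Julia dispatch picks the most specific method. A single rule registered on the abstract base class is then picked up by every subtype, including one whose generic AD path would have worked fine.

```python
from typing import Callable, Dict, Optional


class Distribution:  # stands in for an abstract type such as Distributions' Distribution
    pass


class Normal(Distribution):  # a subtype the rule author had in mind
    pass


class GPMarginal(Distribution):  # hypothetical subtype defined in another package
    pass


# Registry mapping a type to a custom logpdf "rule" (a label, for illustration).
RULES: Dict[type, Callable[[object], str]] = {}


def find_rule(obj: object) -> Optional[Callable[[object], str]]:
    # Walk from the most specific class to the most generic one and return the
    # first registered rule, loosely mirroring most-specific-method dispatch.
    for cls in type(obj).__mro__:
        if cls in RULES:
            return RULES[cls]
    return None


def pullback_logpdf(obj: object) -> str:
    rule = find_rule(obj)
    if rule is not None:
        return rule(obj)
    return "generic AD"  # no rule: AD works through the primal code itself


# A single rule on the abstract base class...
RULES[Distribution] = lambda d: "catch-all rule"

# ...is picked up by every subtype, including ones its author never saw.
print(pullback_logpdf(Normal()))      # catch-all rule
print(pullback_logpdf(GPMarginal()))  # catch-all rule
```

Deleting the `Distribution` entry from `RULES` sends `GPMarginal` back to `"generic AD"`, which is the situation where the rule doesn't exist.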

@mohamed82008 any thoughts on how this implementation could be made less aggressive? It's currently blocking Stheno-Turing integration, and is related to this issue.

@devmotion
Member

One could keep an explicit list of supported and tested distributions in DistributionsAD, and only define it for those.

However, I think a proper fix would be to change the implementation of logpdf in Distributions such that the defaults do not use the in-place method logpdf! but just map with the out-of-place logpdf for single samples (that seems also much more robust in general). Then the Zygote definitions here could be removed, it seems.
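For illustration only, the two styles of batch default can be contrasted in a small Python sketch (names are hypothetical; this is not the Distributions.jl code). The mutating version fills a preallocated buffer, the pattern an in-place `logpdf!` follows; the non-mutating version maps a single-sample, out-of-place log density over the samples. Both compute the same values, but only the second avoids array mutation, which reverse-mode AD tools such as Zygote do not support.

```python
import math
from typing import List, Sequence


def logpdf_normal(mu: float, sigma: float, x: float) -> float:
    # Out-of-place, single-sample log density of Normal(mu, sigma).
    z = (x - mu) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2.0 * math.pi)


def logpdf_batch_inplace(mu: float, sigma: float, xs: Sequence[float], out: List[float]) -> None:
    # Mutating default: writes each result into a preallocated buffer.
    for i, x in enumerate(xs):
        out[i] = logpdf_normal(mu, sigma, x)


def logpdf_batch_map(mu: float, sigma: float, xs: Sequence[float]) -> List[float]:
    # Non-mutating default: map the single-sample method over the batch.
    return [logpdf_normal(mu, sigma, x) for x in xs]


xs = [0.0, 1.0, -2.0]
buf = [0.0] * len(xs)
logpdf_batch_inplace(0.0, 1.0, xs, buf)
assert buf == logpdf_batch_map(0.0, 1.0, xs)  # same values, different style
```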

In general, I try to not look too carefully at the implementations in DistributionsAD - there's type piracy all over the place, and as the example shows it can lead to all kinds of problems...

@willtebbutt
Member Author

> However, I think a proper fix would be to change the implementation of logpdf in Distributions such that the defaults do not use the in-place method logpdf! but just map with the out-of-place logpdf for single samples (that seems also much more robust in general). Then the Zygote definitions here could be removed, it seems.

This seems reasonable to me. Is this something that the Distributions folks might be receptive to do you think?

@devmotion
Member

> Is this something that the Distributions folks might be receptive to do you think?

I don't know, but I assume they might be fine with it. I mean, it seems reasonable to me even without Zygote 🤷

@mohamed82008
Member

> @mohamed82008 any thoughts on how this implementation could be made less aggressive?

This is a hard one. Without making Distributions.jl compatible with Zygote, we need a catchall adjoint here to be an adjoint for Distributions.jl's catchall method. In your case, a workaround would be to define an adjoint for your method that calls pullback on another function name that has no adjoint.

Thinking about the bigger problem linked in that issue (I didn't read the whole issue, so I'm not sure if this has been discussed), I think we can essentially formalise the workaround used here by adding an additional dispatch layer that lets you modify the "method-rrule matching rule". Imo, every method should have its own adjoint. If a more specific method was important enough to have in the forward pass, then it makes sense that we may need to special-case the reverse pass. But sometimes we may not need that, because a sufficiently generic reverse pass can be the adjoint of many forward methods. Imo this problem can be mostly solved by giving more control to developers and perhaps changing defaults. So when I define a new Julia function, I could tell ChainRules: please don't match my function using Julia's multiple dispatch criteria, but treat it as its own thing. This could literally be implemented under the hood using the workaround proposed here, i.e. defining a "bridge rule" that calls another function with no adjoint methods.

We can also have an option at the rule definition site telling ChainRules not to match the rule to any forward method whose signature is more specific. Instead, only apply this adjoint to the "specific signatures" provided. This double-headed approach can let us mix and match between "method-rrule matching rules", sometimes using normal multiple dispatch where it seems to not break things and other times matching the exact signature. But more importantly, the approach proposed here lets the user opt out of the rules defined in ChainRules at the forward method definition site.

@mohamed82008
Member

So now when I define a new function, if I suspect ChainRules is hurting my performance/correctness, I can either define a correct and performant rule for my method or opt out of ChainRules for this method. When defining a new type, it's more complicated because we automatically "sign in" with a few methods in the forward-pass. So perhaps we can also provide an opt-out mechanism based on types not just methods.

@willtebbutt
Member Author

Thanks for your thoughts @mohamed82008; I think we're on the same page with regard to the problem.

> I can either define a correct and performant rule for my method or opt out of ChainRules for this method.

Could you elaborate with some pseudo / example code or something? I'm struggling to understand what you're proposing, but would be keen to understand better.

I would generally be much more in favour of an opt-in mechanism. My reasoning is that we should view an inability of AD to automatically derive a rule as the norm, rather than the exception.

@mohamed82008
Member

> Could you elaborate with some pseudo / example code or something? I'm struggling to understand what you're proposing, but would be keen to understand better.

Defining a correct or performant rule is easy, just use ChainRules. Opting out can be done by overloading Zygote.has_chain_rrule (https://github.com/FluxML/Zygote.jl/blob/2fc416464ca4910d19618f589b0c93f595b16afb/src/compiler/chainrules.jl#L12) which I think should live in ChainRules anyways.

@mohamed82008
Member

mohamed82008 commented Oct 19, 2020

An opt-in mechanism (opt out by default) would be hard to implement, though. This is because when we check for a rule, we check the concrete types to see if there is a method of rrule that can take these types as arguments. An opt-out-by-default approach means that any rrule defined on abstract types will never get a match, unless there is a rule for the particular concrete types that "forwards" to the abstract rrule definition. This doesn't feel Julian at all, in the sense that we won't be taking advantage of the type system or multiple dispatch to simplify code. For example, we would need rules for Float64, Float32, Float16, DoubleFloat, etc. I don't think that's feasible. A default opt-in approach, with occasional opting out here or there for special types, makes more sense to me from an implementation point of view.

@mohamed82008
Member

I think what you are really advocating for here is rule definition for "narrow" abstract types, e.g. AbstractFloat instead of Real, or DenseVector instead of AbstractVector. Trait-based rule matching would be useful here as well, so that rules can make certain "assumptions" about their inputs, e.g. that they are dense, sparse, O(1)-sized, etc. The rule lookup would then need to find an appropriate rule based on those traits. I think trait-based matching makes more sense than type-based matching or defining rules on concrete types only. Proper language-level support for traits may help here; see the discussion in JuliaLang/julia#37790.
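As a rough illustration of trait-based matching (a Python toy, not a proposal for a specific Julia API), the sketch below keys rules on a computed storage trait rather than on a position in a type hierarchy. `SparseVector` and `storage_trait` are hypothetical. The rules implement the pullback of `sum`: every stored entry receives the incoming cotangent `ybar`.

```python
from typing import Callable, Dict


class SparseVector(dict):
    # Hypothetical sparse vector: maps stored indices to values.
    pass


def storage_trait(x: object) -> str:
    # Hypothetical trait function: classifies an input by its storage layout,
    # independently of where its type sits in any class hierarchy.
    if isinstance(x, SparseVector):
        return "sparse"
    if isinstance(x, list):
        return "dense"
    return "unknown"


# Pullback rules for sum(x), keyed by trait rather than by type. Each returns
# the gradient with respect to the stored entries, scaled by the cotangent ybar.
RULES: Dict[str, Callable[[object, float], object]] = {
    "dense": lambda x, ybar: [ybar] * len(x),
    "sparse": lambda x, ybar: SparseVector({i: ybar for i in x}),
}


def pullback_sum(x: object, ybar: float = 1.0) -> object:
    rule = RULES.get(storage_trait(x))
    if rule is None:
        return "generic AD"  # no trait matched: leave it to AD
    return rule(x, ybar)
```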

@willtebbutt
Member Author

> Defining a correct or performant rule is easy, just use ChainRules. Opting out can be done by overloading Zygote.has_chain_rrule (https://github.com/FluxML/Zygote.jl/blob/2fc416464ca4910d19618f589b0c93f595b16afb/src/compiler/chainrules.jl#L12) which I think should live in ChainRules anyways.

Hmm yes. This could also be done by writing an rrule that returns nothing.
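The opt-out idea can be sketched in the same toy style (Python, hypothetical names, not the Zygote or ChainRules API). It is modelled on the convention that ChainRulesCore's fallback `rrule` returns `nothing` to signal "no rule": a more specific rule that returns `None` shadows a broad rule and sends control back to generic AD. The design choice that matters is the `break`: once the most specific rule declines, less specific rules are not consulted.

```python
from typing import Callable, Dict, Optional


class Distribution:
    pass


class Normal(Distribution):
    pass


class GPMarginal(Distribution):  # hypothetical subtype that wants generic AD
    pass


RULES: Dict[type, Callable[[object], Optional[str]]] = {}


def pullback_logpdf(d: object) -> str:
    for cls in type(d).__mro__:
        if cls in RULES:
            result = RULES[cls](d)
            if result is not None:
                return result
            # The most specific rule returned None: treat that as "no rule"
            # and do not fall through to less specific rules.
            break
    return "generic AD"


RULES[Distribution] = lambda d: "catch-all rule"  # broad rule from another package
RULES[GPMarginal] = lambda d: None               # opt-out: more specific, returns None
```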

I mean, the best way to implement an opt-in mechanism is to just not define rrules for abstract types, and whenever you find that AD doesn't work for a particular type, define an rrule that calls the default function. For example, there are finitely many Distribution subtypes defined in Distributions.jl. For each of them you could define (maybe via metaprogramming) a trivial rrule that calls some rrule-like function (maybe a function called loglikelihood_rrule_helper or something).
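An opt-in registry along these lines can be sketched in Python (a toy model; `loglikelihood_rrule_helper` is the hypothetical name from the comment, and the class names stand in for Distributions.jl types). Lookup uses the exact type only, and a loop over an explicit list of supported types plays the role of metaprogramming. Types that are not listed, such as a subtype from another package, fall through to generic AD instead of being captured by a broad rule.

```python
from typing import Callable, Dict


class Distribution:
    pass


class Normal(Distribution):
    pass


class Exponential(Distribution):
    pass


class GPMarginal(Distribution):  # hypothetical subtype from another package
    pass


def loglikelihood_rrule_helper(d: Distribution) -> str:
    # Hypothetical shared implementation that the per-type rules forward to.
    return f"hand-written rule for {type(d).__name__}"


# Exact-type registry: lookup never walks the class hierarchy, so a type
# gets a rule only if it has been opted in explicitly.
RULES: Dict[type, Callable[[Distribution], str]] = {}

SUPPORTED = (Normal, Exponential)  # finite, explicit list of opted-in types
for cls in SUPPORTED:              # loop stands in for @eval-style metaprogramming
    RULES[cls] = loglikelihood_rrule_helper


def pullback_logpdf(d: Distribution) -> str:
    rule = RULES.get(type(d))  # exact match only
    return rule(d) if rule is not None else "generic AD"
```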

More generally, what DistributionsAD is doing has very much the same symptoms as type piracy. Suppose I write some code in my package that works just fine with AD when I don't load Turing. Then I load Turing, and it breaks. This is exactly the kind of thing we try to prevent by avoiding type piracy.

One way to reason about this as type piracy is to consider that when I wrote my code, I also implicitly "wrote" a method of pullback (or whatever function @adjoint spits out) by not defining a method of pullback, i.e. I explicitly intended the method that Zygote implements automatically to be the one that is used. DistributionsAD then goes and defines a more specific method of pullback that overrides this default behaviour.

I will grant you that you could construe this as a problem with the way that Zygote works, but it feels like the kind of thing that we should really be solving at the ChainRules level.

@mohamed82008
Member

The DistributionsAD approach is breaking all the Julia rules and it needs to go. But this package was born out of the need to "fix" differentiating most of the distributions using all the AD packages. This meant different workarounds for different packages. Some of those "workarounds" made it back to ReverseDiff or were changed to using ChainRules, while others remained. In a way, defining rrules on any type we don't own is type piracy. But doing so on abstract types is especially bad for the reason you outline. So in summary, I am in favour of removing the method in question here if removing it doesn't break anything or if you have a better implementation.
