Broadcasting with differentiable functions (Remove Cast?) #12

Closed
willtebbutt opened this issue Apr 12, 2019 · 5 comments

@willtebbutt
Member

The current implementation of broadcast assumes that the function being broadcasted doesn't contain any differentiable bits, and that we can therefore safely assume that there is no gradient information to be associated with it. It also assumes that the forwards-inside-reverse-mode trick is the correct choice for implementing the adjoint, which isn't necessarily the case.
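
To make the first assumption concrete, here is a minimal sketch. Python is used only as a neutral illustration language; this is not ChainRules code. The broadcast function is a callable object holding a differentiable parameter `w`, so a complete pullback has to return a cotangent for the function itself as well as for the input. `ScaledSquare`, its `partials` method and `broadcast_rrule` are hypothetical names invented for this example.

```python
from dataclasses import dataclass


@dataclass
class ScaledSquare:
    """Hypothetical callable with a differentiable field: f(x) = w * x**2."""
    w: float

    def __call__(self, x: float) -> float:
        return self.w * x * x

    def partials(self, x: float) -> tuple[float, float]:
        # (df/dx, df/dw) evaluated at x
        return 2.0 * self.w * x, x * x


def broadcast_rrule(f: ScaledSquare, xs: list[float]):
    """Return f.(xs) plus a pullback giving cotangents for both f and xs."""
    ys = [f(x) for x in xs]
    parts = [f.partials(x) for x in xs]

    def pullback(ybar: list[float]):
        xbar = [yb * dx for yb, (dx, _) in zip(ybar, parts)]
        # w is shared by every element, so its cotangent is a sum
        wbar = sum(yb * dw for yb, (_, dw) in zip(ybar, parts))
        return {"w": wbar}, xbar

    return ys, pullback


f = ScaledSquare(w=3.0)
ys, back = broadcast_rrule(f, [1.0, 2.0])
fbar, xbar = back([1.0, 1.0])
# ys == [3.0, 12.0]; xbar == [6.0, 12.0]; fbar == {"w": 5.0}
```

A rule that assumes the broadcast function carries no gradient would return nothing for `w` and would silently drop the `5.0` above.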

Presumably this implementation is a placeholder; however, it will definitely be necessary to relax the above assumptions before e.g. Zygote is able to adopt ChainRules, so I believe it should be addressed as a priority.

@jrevels
Member

jrevels commented Apr 14, 2019

> The current implementation of broadcast assumes that the function being broadcasted doesn't contain any differentiable bits

Relevant dev note (will turn into an advanced usage doc section or something at some point):

> Presumably this implementation is a placeholder; however, it will definitely be necessary to relax the above assumptions before e.g. Zygote is able to adopt ChainRules, so I believe it should be addressed as a priority.

It's a half-placeholder 😉: it kind of expresses the idea, but the implementation is mainly a toy one compared to e.g. https://github.com/jrevels/MixedModeBroadcastAD.jl.

Regardless (unless I'm misunderstanding something, which could always be the case), none of the default rule definitions should be a barrier to adoption by downstream ADs, since downstream ADs can overload rules in whatever manner best fits their specific implementations. Zygote can define the broadcast rule in whatever way makes the most sense for Zygote.

> It also assumes that the forwards-inside-reverse-mode trick is the correct choice for implementing the adjoint, which isn't necessarily the case.

It assumes that forward mode is the correct choice of "fallback" mode for unary scalar functions, which is almost certainly the case. It doesn't make any assumptions about non-unary functions, and/or functions where there's a specialized broadcast rule that could be defined ahead of time.

Anyway, even if we had a more general n-ary fallback rule for broadcast, such a rule would require selecting a default fallback mode. Forward mode is generally a much safer choice for that, since broadcast kernels tend to have low arity (almost never more than ~100 arguments, and often fewer than ~10) and forward mode places fewer constraints on kernel expressiveness than reverse mode. Obviously, in a non-fallback situation, one can do better for kernels that have known differentiation rules ahead of time, or even leverage IR analysis plus runtime type information to make a more intelligent guess about the appropriate mode. But that kind of approach is more in downstream ADs' domain and probably isn't something we'd want ChainRules to dictate.
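
A minimal sketch of the forward-inside-reverse fallback for an n-ary scalar kernel, assuming plain dual numbers that carry one partial per kernel argument (Python for illustration only; `Dual` and `broadcast_rrule_forward` are hypothetical names, not ChainRules or ForwardDiff API). The forward pass runs the kernel once per element on seeded duals and stores the partials. The pullback then only multiplies the stored partials by the incoming cotangent. Every dual operation carries n partials, which is why the cost grows with arity and why low-arity kernels suit this approach.

```python
from dataclasses import dataclass
from typing import Callable, Sequence


@dataclass(frozen=True)
class Dual:
    """Value plus a tuple of partials (one slot per kernel argument)."""
    val: float
    eps: tuple

    def _lift(self, other):
        # Treat plain numbers as constants with zero partials
        if isinstance(other, Dual):
            return other
        return Dual(float(other), (0.0,) * len(self.eps))

    def __add__(self, other):
        o = self._lift(other)
        return Dual(self.val + o.val, tuple(a + b for a, b in zip(self.eps, o.eps)))

    __radd__ = __add__

    def __mul__(self, other):
        o = self._lift(other)
        # Product rule applied to every partial slot
        return Dual(
            self.val * o.val,
            tuple(self.val * b + o.val * a for a, b in zip(self.eps, o.eps)),
        )

    __rmul__ = __mul__


def broadcast_rrule_forward(kernel: Callable, *args: Sequence[float]):
    """Forward-mode-inside-reverse-mode fallback for an n-ary scalar kernel."""
    n = len(args)
    # Seed argument i with the i-th unit vector
    seeds = [tuple(1.0 if j == i else 0.0 for j in range(n)) for i in range(n)]
    outs = [
        kernel(*(Dual(float(a), seeds[i]) for i, a in enumerate(elems)))
        for elems in zip(*args)
    ]
    ys = [o.val for o in outs]

    def pullback(out_bar: Sequence[float]):
        # Cotangent for argument i at element k: out_bar[k] * d kernel / d arg_i
        return tuple(
            [ob * o.eps[i] for ob, o in zip(out_bar, outs)] for i in range(n)
        )

    return ys, pullback


ys, back = broadcast_rrule_forward(lambda x, y: x * y + x, [1.0, 2.0], [3.0, 4.0])
x_bar, y_bar = back([1.0, 1.0])
# ys == [4.0, 10.0]; x_bar == [4.0, 5.0]; y_bar == [1.0, 2.0]
```

The kernel here is an ordinary function that is only ever run forward on duals, so no per-element tape is recorded. That is one way to read the "fewer constraints on kernel expressiveness" point above.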

@willtebbutt
Member Author

Okay, sounds good to me. Could you elaborate a little on how a downstream package would change the default behaviour? If the intention is to literally override the default method in ChainRules, won't this cause annoying warnings? (Not that this would be the end of the world, of course.)

@willtebbutt
Member Author

Or is the intention that Cassette will come to the rescue?

@oxinabox changed the title from "Broadcasting with differentiable functions" to "Broadcasting with differentiable functions (Remove Cast?)" on Aug 23, 2019
@oxinabox
Member

related #122

@oxinabox
Member

Most parts of this are resolved or put in other issues.
