The Extensibility Problem, for propagator closures #53
Comments
I think this is something to keep in mind. This also relates to the fact you can't actually pass Interesting idea: if we did go to pullback as a global function, taking a signature, and (Ȳ) and some extra information, we might be able to encode that extra information as a closure that has the default implementation. Ideally, we would dispatch the But that requires global information, some of which violates the halting problem. The input to the pullback (Ȳ) has to be a very similar type (need to be able to subtract them, I think?) to the output of the forward pass (Y). Definitely this issue is one to think on.
How about this. It can solve a few problems:
We also need a function
So this is what it might look like:
Bonus fact, that may or may not apply to storing sig as part of namedtuple It might kinda be part of replacing
I'd be interested in more specifics of the use case for this and the kind of extensibility you need. My main issue with a separate pullback function is that it simulates closures (i.e. a bundle of data + code) anyway; you're going to end up with something equivalent but much less nice to use.

    function forward(::typeof(*), A::AbstractMatrix, B::AbstractMatrix)
        return A * B, C -> pullback((signature=(typeof(*), typeof(A), typeof(B)), A=A, B=B), C)
    end

Of course, writing things out this way isn't that helpful if you have to do it for every rule. But the only real difference is that the pullback has a name, which gives you an interface to overload it. We could get the same effect with something like (with appropriate sugar)

    P = pullback_name(Tuple{typeof(*),AbstractMatrix,AbstractMatrix})
    (::typeof(P))(Y) = ...

However, while this would solve the problem in a general way, I'm sceptical that even this is really needed. Why can't you make the adjoint a named type, rather than a named tuple, and overload
That will be the case with #8, I am calling that
The main case is that being a linear map practically is not enough for some pullbacks.
This is tricky to do when mixing-and-matching custom adjoints with automatically derived ones as automatically derived rules will always produce a
Sure, but there's a finite number of such adjoints (I think just Customising your adjoint type is something we can define a clear interface for and it'll work with custom adjoints defined in outside packages, whereas if you manually override each pullback it's only going to work with the set you specifically overloaded.
I am in favour of waiting and seeing. We will certainly have a way to convert a Particularly, since under my plan the extra info you need will be housed in the default closure propagator anyway.
Here is an instance of this in the wild for Zygote It special cases So I think we should do
Because of math reasons, it is very rare to get an unexpected type being passed to the pullback. Thus generally one just adds another There is a bit more to this story w.r.t. arrays but we have a bit more of that story encoded e.g. in ProjectTo
Here's a thing that isn't currently possible and is, I believe, something that we might actually want to care about. Consider the pullback for `AbstractMatrix` multiplication:

Provided that `Ȳ` is itself an `AbstractMatrix`, for which `*` with other matrices will be correctly defined assuming `getindex` is correctly defined, something correct will happen even if it's slow.

Now consider the case that `Ȳ` is a `NamedTuple`, possibly because `Y` is some non-`Matrix` `AbstractMatrix`. Now what happens? The above breaks: `*` isn't defined for `NamedTuple`s, nor is it possible to extend `times_pullback` to handle `Ȳ` from outside the original `rrule` definition. One's only recourse is to add a completely new definition of `pullback` for `*` with the `AbstractTypeofA` and `AbstractTypeofB` one expects to see `Ȳ` with, which itself presupposes that a method of `pullback(typeof(*), ::AbstractTypeofA, ::AbstractTypeofB)` doesn't already exist; if one does, no option is available but to modify the existing method.

Phrased differently, the current design requires that each `rrule` be implemented to handle every possible type of `Ȳ` that it might ever see. This is clearly an unreasonable requirement because Julia permits the creation of new types, and the input types to `pullback` do not uniquely specify the type of `Ȳ`. Assuming this unreasonable requirement is unmet, we are left with two options when the above is encountered:

1. Modify the existing `rrule` to handle the new type encountered (the bad-ness of this is assumed self-evident).
2. Add a new `rrule` specialised to different types of `A` and `B`. This really isn't great from a code re-use perspective.

`*` is not a pathological case because the forwards-pass is quite straightforward, but other cases are worse.

This problem appears to manifest itself in cases where the forwards-pass is perfectly good for multiple types, but the reverse-pass requires care. For example `Diagonal * Matrix`: the forwards-pass and the data required on the reverse-pass are no different than for `Matrix * Matrix`, but the reverse-pass implementation is necessarily quite different.

This lack of extensibility is a direct consequence of the `(value, back) = pullback(...)` design choice that `ChainRules` / `Zygote` make. `Nabla` made a slightly different design choice in which the forwards- and reverse- bits of a `pullback` were separate functions, so you could extend things. This design doesn't share the extensibility issue that the `ChainRules` / `Zygote` style presents, but equally doesn't immediately enable the same sharing of state on the forwards- and reverse-passes.

One possible resolution would be to adopt the separated forwards- and reverse- passes chosen in `Nabla`, and allow an arbitrary communication object to be shared between the forwards- and reverse- passes.

A `forward` call would therefore not return a closure, but rather whatever intermediate data is deemed by the implementer to be important for the reverse-pass. `pullback` is then called with the appropriate signature to evaluate the adjoint. This interface is somewhat more verbose than the closure-based interface due to the need to copy-paste the signature all over the place, although this may be alleviated with some careful metaprogramming tooling. I would anticipate that we would also see an improvement in stack-trace readability, since we would get direct calls to a `pullback` function, with the types of `forward`'s arguments placed prominently.

In summary, this change buys us the ability to extend reverse-pass behaviour using multiple dispatch, and hence code-reuse, at the expense of increased verbosity and the need to explicitly specify the data from the forwards-pass that may be required on the reverse.
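To make the trade-off concrete, here is a rough sketch of what the separated design could look like. Everything in it is an assumption made for illustration: the two-value return of `forward`, the `pullback(f, data, Ȳ)` argument order, and the `(diag=...,)` named tuple standing in for a structured cotangent are hypothetical, and none of it is the actual API of `ChainRules`, `Zygote`, or `Nabla`.

```julia
using LinearAlgebra

# Hypothetical interface: `forward` returns the primal value together with a
# plain "communication object" (here a NamedTuple) rather than a closure.
function forward(::typeof(*), A::AbstractMatrix, B::AbstractMatrix)
    return A * B, (A=A, B=B)  # data the reverse pass will need
end

# Default reverse pass. `pullback` is an ordinary global function, so it can
# gain new methods later. This method is valid whenever Ȳ acts like a matrix.
function pullback(::typeof(*), data, Ȳ::AbstractMatrix)
    return Ȳ * data.B', data.A' * Ȳ
end

# An extension that could live outside the package that wrote the rule above:
# Ȳ arrives as a NamedTuple describing a Diagonal (an assumed representation),
# so rebuild the structured matrix and reuse the default method.
function pullback(::typeof(*), data, Ȳ::NamedTuple{(:diag,)})
    return pullback(*, data, Diagonal(Ȳ.diag))
end

A = [1.0 2.0; 3.0 4.0]
B = [0.5 0.0; 0.0 2.0]

Y, data = forward(*, A, B)

# The NamedTuple cotangent is handled without touching the original methods.
dA, dB = pullback(*, data, (diag=[1.0, 1.0],))
```

The third method is the point of the exercise: with the closure-based design there is no name to attach it to, whereas here it is just another method of `pullback`. The cost is visible too, since `forward` has to say explicitly which data the reverse pass needs.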
This isn't something that needs resolving immediately, but I feel it should be given some consideration so that we can at least be aware that this is an issue we're choosing to ignore if nothing is done about it.
Side note: this appears to be analogous to the expression problem, on which you can consult Stefan's JuliaCon talk. Specifically, that we can't define new methods of `back` for existing `pullback`s is (I think) analogous to not being able to define new methods that extend the functionality of existing types.