Forward mode #175
Conversation
Other changes: Node.progenitors is now a dict, with reverse mode progenitors mapped to None and forward mode progenitors mapped to their forward mode gradient. Added a few extra tests to test_systematic.py. Still no tests for forward mode.
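For illustration, the new Node.progenitors layout looks roughly like this (placeholder objects standing in for the real progenitor nodes; this is just a sketch of the shape of the data, not the actual classes):

```python
import numpy as np

reverse_progenitor = object()  # stand-in for a reverse-mode progenitor
forward_progenitor = object()  # stand-in for a forward-mode progenitor

progenitors = {
    reverse_progenitor: None,                  # reverse mode: no payload
    forward_progenitor: np.array([1.0, 0.5]),  # forward mode: its forward gradient
}
```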
autograd/core.py (outdated):

```python
ingrads[progenitor].append(ingrad)
assert_vspace_match(ingrad, result.vspace, self)
result.progenitors.update({progenitor: vsum(result.vspace, *ingrads[progenitor])
                           for progenitor in ingrads})
```
↑ This block should maybe be moved into a separate update_forward_grads function.
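Something like this, perhaps (a rough sketch based only on the diff above, untested):

```python
def update_forward_grads(result, ingrads):
    # ingrads maps each forward-mode progenitor to the list of incoming
    # forward gradients accumulated from the primitive's arguments.
    result.progenitors.update(
        {progenitor: vsum(result.vspace, *grads)
         for progenitor, grads in ingrads.items()})
```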
Whoa, this is great! I wonder how much mileage we can get out of adding forward-mode tests. I suppose we really should be using some sort of code coverage tool too, since right now we can silently forget to add tests.
Yeah, this is awesome! We probably also want to compute hessian differently, like forward_mode_jacobian(jacobian(fun)) or something. Also, I think this forward-mode defines a derivative operation (with respect to a scalar parameter), but we should consider defining a Jacobian-matrix product operation and thus allowing the loop over input dimensions to happen with each primitive (rather than just as an outer loop as in jacobian). That could be a huge speed win, both in avoiding autograd overhead and in leveraging more numpy to avoid Python overhead. It may be that those Jacobian-matrix products are hard to implement for all primitives, but it would be great for some primitives to support it, even if not all do. (We should think about adding a matrix-Jacobian product operation for reverse-mode too!)
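To make the Jacobian-matrix product idea concrete, here is a plain-numpy sketch (not autograd code) for the single composition f(x) = tanh(Wx): the product with a whole matrix of tangent directions comes out of two numpy calls, versus a Python loop over input dimensions.

```python
import numpy as np

W = np.random.randn(4, 3)
x = np.random.randn(3)
M = np.eye(3)  # matrix of tangent directions (the full identity, giving the Jacobian)

y = np.tanh(W.dot(x))

# Jacobian-matrix product in bulk: diag(1 - y**2) @ W @ M
jmp = (1.0 - y ** 2)[:, None] * W.dot(M)

# The same thing as an outer Python loop, one jvp per input dimension
loop = np.stack([(1.0 - y ** 2) * W.dot(M[:, i]) for i in range(3)], axis=1)

assert np.allclose(jmp, loop)
```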
Agreed! (see the bottom item of my TODO list)
I'm wondering if we need a way to make a product vspace, for things like the jacobian, whose vspace should be the product of the input vspace with the output vspace... EDIT: A tensor product, to be precise...
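For reference, the existing jacobian wrapper already returns an element of that tensor-product space, at least for arrays: its shape is the output shape followed by the input shape (a quick check, using the convention as I understand it):

```python
import autograd.numpy as np
from autograd import jacobian

def f(x):
    return np.sin(np.dot(np.ones((4, 3)), x))  # maps R^3 -> R^4

J = jacobian(f)(np.zeros(3))
print(J.shape)  # (4, 3): output dimensions first, then input dimensions
```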
Also, someone at NIPS was telling me it's possible to calculate Hessians 'in one go' with a single reverse pass; the method is outlined here: http://www.di.ens.fr/~rgower/pdf/HighOrderReverseAD. I imagine this would be faster for us than the multi-pass approach, but the implementation requires defining the Hessian of each primitive, which could be tedious.
Re: defining Hessian 'pullback' functions for primitives, as in my comment about implementing matrix-Jacobian products: a key advantage of autograd over other tools is its flexibility and simple implementation. I think it would be great to have a clean way to add Hessian operations to primitives, even if we only implement it for a few (as the need arises) and as long as the implementation is small and modular.
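As a strawman, the registration could look something like this (all names are hypothetical; nothing like this exists in autograd, and the exact contraction semantics of the second-order pullback would still need pinning down):

```python
import autograd.numpy as anp

HESSIAN_PULLBACKS = {}  # primitive -> argnum -> maker of a second-order pullback

def defhessianvp(fun, maker, argnum=0):
    HESSIAN_PULLBACKS.setdefault(fun, {})[argnum] = maker

# Elementwise tanh: first derivative is 1 - tanh(x)**2, second derivative is
# -2 tanh(x) (1 - tanh(x)**2), applied elementwise to the incoming vector.
defhessianvp(anp.tanh, lambda ans, x: lambda v: -2.0 * ans * (1.0 - ans ** 2) * v)
```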
Implement more grads and forward mode tests in combo_check. Have introduced dependency on a module called orderedset, should be able to fix this by just using a list.
I realised my implementation was pretty naive, because I paid no attention to what order forward gradients were calculated within each primitive call... I think things are now correct, but I had to use this orderedset thing for the active_progenitors. How do you guys feel about introducing this extra dependency? If necessary I can try to work out another way to do it that isn't too clunky. I've also added forward mode tests to combo_check.
@j-towns - that seems like a slightly esoteric dependency which doesn't appear to be very actively maintained or supported - e.g. no (Windows) wheels, no conda packages. If you need a sorted set you may be better off using https://www.youtube.com/watch?v=7z2Ki44Vs4E ...that said, if there's a performance benefit to using orderedset...
One option would be to use an OrderedDict and have everything point to None.
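e.g. a minimal insertion-ordered set along those lines (a sketch, not what ended up in the PR):

```python
from collections import OrderedDict

class OrderedSet(object):
    """Insertion-ordered set built on OrderedDict; every key maps to None."""
    def __init__(self, iterable=()):
        self._d = OrderedDict((item, None) for item in iterable)

    def add(self, item):
        self._d[item] = None

    def __contains__(self, item):
        return item in self._d

    def __iter__(self):
        return iter(self._d)

    def __len__(self):
        return len(self._d)
```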
Removed orderedset dependency
Also use jacobian_vector_product for doing hessian_vector_products.
Assuming the TF interpreter isn't factoring out the computation of g…
Ah wait, I think I'm with you on the runtime FLOPs now, because all of the mappings in the reverse pass are linear as functions of v. I would still argue that this assertion should be proved by doing some actual timings.
Yes, that's right. There's a graph to draw here, but I don't have your ascii art skills. For chain-structured functions, we have three chains involved here: the top one is from the function eval, the middle one from the first call to tf.gradients, and the bottom one from the second call to tf.gradients. As you say, the bottom one doesn't actually depend on the values from the middle one, by linearity. Both the bottom chain and the middle chain depend on the values of the top chain.

Even better, there won't even be extra memory overhead at runtime: because the top chain and bottom chain have arrows pointing in the same direction (i.e. to the right), and the arrows from the top chain to the bottom don't skip steps, the TF graph executor is smart enough to reuse memory as it performs the evaluation, rather than evaluating things in a bad order and keeping too large an active set.

One way to convince yourself that g isn't evaluated at runtime is the fact that we don't even need to supply a value for v in feed_dict; that variable stays unbound, and we don't get an error because we never try to evaluate a node that depends on it.
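In code, the construction we've been discussing looks roughly like this (TF 1.x graph mode, assuming single-tensor ys and xs; the names here are just for illustration):

```python
import tensorflow as tf

def fwd_gradients(ys, xs, d_xs):
    # Dummy cotangent: it never needs to be fed, because the first reverse
    # pass is linear in u, so the second pass never uses its value.
    u = tf.placeholder(ys.dtype, shape=ys.get_shape())
    g = tf.gradients(ys, xs, grad_ys=u)      # u -> u^T J, linear in u
    return tf.gradients(g, u, grad_ys=d_xs)  # transpose again: gives J d_xs
```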
Reverse mode FTW again! 😎
Here is a quick attempt at a graph. Each op in the blue chain is linear in the arguments moving right to left. All that is my current thinking, but it isn't really tested, and there may be TF implementation details that we haven't accounted for.
Bruce Christianson claims in his paper "A Leibniz Notation for Automatic Differentiation" (Theorem 1): "The three algorithms forward-over-reverse, reverse-over-forward and reverse-over-reverse are all numerically stepwise identical, in the sense that they not only produce the same numerical output values, but at every intermediate stage perform exactly the same floating point calculations on the same intermediate variable values." If true, then we can see that the forward-mode calculation of some original function…
I see second derivatives there. Doesn't his theorem generalise to the equivalence we were discussing last week (see above) between different methods for calculating Hessian-vector products? Not to be confused with what we're talking about here, which is Jacobian-vector products.
As an aside, I agree with his point that Leibniz notation > Newton notation, although what he calls his 'bariential' notation seems unnecessary. Edit: sorry that this thread is getting ridiculously long now...
Good eye on the 2nd derivatives there! I like that the autograd issue tracker is a hotbed for autodiff discussion, though it would be even better if the information here were eventually distilled into some docs, or even the long-promised autograd paper.
I've written a blog post on this: https://j-towns.github.io/2017/06/12/A-new-trick.html. @mattjj and @jekbradbury, I wasn't sure how to credit you guys; I've linked to this thread at the top of the post. If you want me to include your names (or anyone else's) explicitly, let me know.
Just chiming in to say how awesome and satisfying this all is. It's like something out of SICP. It seems obvious in retrospect: how do you get forward mode? By reversing reverse mode! :)
Awesome blog post, @j-towns! I think the way you wrote it is great; citing the thread is a good way to share credit. It might have broader audience appeal if you include the TF implementation in the post, too, since TF is so popular right now (maybe linking to my comment in the thread). @duvenaud I agree, and I think we were right on top of this idea for a while without noticing it.
Good point. I've put a link at the very bottom of the blog post and also on the post I've made on reddit. Glad you like it by the way! 🤓
Hmm, I'm gonna have a go at running my forward mode implementation, but applying the new vjp(vjp) method on a primitive-by-primitive basis to get each primitive's jvp (rather than using the jvps that I've hand-implemented). It will be interesting to see how much overhead this introduces over the hand-implemented jvps (this was discussed a bit above; I reckon there will probably be some overhead). If the overhead is negligible, this is great news, because it means we only have to maintain the vjps and can get rid of all the jvps. Edit: I guess at the very least this could be a handy fallback for primitives where I haven't written a jvp.
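For reference, the function-level version of the trick looks roughly like this with today's autograd API (make_vjp's calling convention here is the current one and may differ from the code in this PR; the per-primitive version would do the same thing inside each primitive):

```python
import autograd.numpy as np
from autograd import make_vjp

def make_jvp_via_double_vjp(f, x):
    vjp, y = make_vjp(f)(x)          # vjp: u -> u^T J_f(x), linear in u
    # Reverse-differentiating the linear map vjp transposes it, giving v -> J_f(x) v.
    # The point we linearise at doesn't matter, again by linearity.
    jvp, _ = make_vjp(vjp)(np.zeros(np.shape(y)))
    return jvp  # jvp(v) evaluates J_f(x) v using only reverse-mode machinery
```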
Finally catching up on this thread. Beautiful work, @j-towns and @mattjj! @duvenaud is right, it's one of those elegant results that's natural and obvious once it's pointed out, but takes some real wizardry to get to in the first place. I've been trying to think of ways we can avoid unnecessary FLOPs (FLOs?). Off the top of my head, here are three options for figuring out linearity:
A bonus side effect of doing this is that we can use the linearity information to avoid storing objects that the vjp doesn't need. As this thread makes clear, people really do care about memory use. Currently, we're needlessly hogging memory with linear operations like matrix multiplies.

But here's one major wrinkle: in order to propagate zeros, we need to be able to predict the output vspace of a function without evaluating it. We've basically been building our own type system within Autograd, so function signatures would be a natural thing to add. It's a bit tricky because the output vspace isn't just a function of the input vspace(s) but also of (non-float) args like 'keepdims'. It's also a bit of a pain to have to write out all these function signatures. On the other hand, it would be hugely helpful for automated testing of primitives.
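As a strawman, a signature for np.sum might look something like this (hypothetical; nothing like this exists in autograd, it uses today's autograd.extend.vspace, and it only handles a single integer axis):

```python
import autograd.numpy as np
from autograd.extend import vspace

def sum_output_vspace(in_vs, axis=None, keepdims=False):
    # Map the input vspace plus the non-differentiable kwargs to the output
    # vspace, without ever evaluating np.sum on real data.
    shape = list(in_vs.shape)
    if axis is None:
        shape = [1] * len(shape) if keepdims else []
    elif keepdims:
        shape[axis] = 1
    else:
        del shape[axis]
    return vspace(np.zeros(shape))

# e.g. sum_output_vspace(vspace(np.zeros((2, 3))), axis=1).shape == (2,)
```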
Yeah, that sounds like a good template for how we could bring the FLO cost down for jvp/ggnvp. I have a minor question on the point about avoiding storing objects that aren't needed: would it actually be possible to know ahead of time that a variable isn't going to be needed by any vjp? It seems to me that if the user does something like this:

```python
def f(x):
    a = np.ones(3)
    y = np.dot(a, x)
    return y**x
```

then although the third line is linear in x, y is still needed by the vjp of the fourth line, so it can't be discarded. Pruning nodes surely wouldn't be able to happen until the whole graph had been constructed, by which point we don't really care, since it's the peak memory usage that we wish to reduce.

EDIT: Thinking a bit more, I suppose there would be cases where you could avoid storing nodes, if a variable is created inside some inner scope which is then exited. To get this to work, linear primitives could be in some way tagged as linear, and the nodes which they output can simply not include the parent in their 'recipe'.
@j-towns, you're raising excellent points. I'm living in the past. There was a brief period of Autograd's history when we separated the value-storing nodes from the rest of the graph bookkeeping. So it looks like what I'm proposing would require re-introducing some sort of second node type.
Ah ok, that all makes sense. I'm just wondering if there's some way we can get the benefits of Python's garbage collection without having two separate Node types (which is nasty IMO). What about keeping Node as it is, but changing what it contains? Edit: I guess it wouldn't be possible to efficiently handle fan-out and fan-in with what I think I was suggesting there.
A minor point: I think it's kind of elegant having the local primitive vjps constructed by a staging function.
Yup, you got it.
You know, I have to agree. That plus the memory optimization and I'm starting to regret changing the interface. The arguments at the time were (1) slight reduction in overhead and (2) simpler interface for users. If it weren't for the albatross of backwards compatibility I'd change it back right now (keeping the vspaces). |
I say make the change! We can figure out a way to make the transition easy enough. Maybe we could require uses of the non-staged defvjp to go through something like this:

```python
def defvjp(self, vjpfun, argnum=0):
    warnings.warn(defvjp_staging_is_better)
    def make_vjp(ans, vs, gvs, *args, **kwargs):
        def vjp(g):
            return vjpfun(g, ans, vs, gvs, *args, **kwargs)
        return vjp
    self.defvjp_staged(make_vjp, argnum)
```

(Not sure where the defvjp_staging_is_better warning message should live.) Is there a problem I'm overlooking with switching between staged and unstaged interfaces like this? IMO autograd's internals should stay agile.
That's the spirit! Technically, defvjp is external because we mention it in the tutorial. But then again, so is our digestive tract. Ok, I'll make a branch with the changes. It'll be nice to finally benchmark it too.
Do we think things should be named as they were before, with a…
…todiff from tracing, forward mode falls out quite naturally as an alternative trace type. We still need to implement the primitive jacobian-vector products. Luckily j-towns has done a lot of this already.
This probably isn't ready to be pulled into the master branch yet, but I thought I'd submit a pr in case you want to track progress.
TODO:
- jacobian_vector_product convenience wrapper
- hessian_vector_product wrapper to use forward mode