
Forward mode #175

Closed · wants to merge 22 commits

Conversation

j-towns commented Jan 2, 2017

This probably isn't ready to be pulled into the master branch yet, but I thought I'd submit a PR in case you want to track progress.

TODO:

  • Implement the rest of the numpy grads
  • Tests for remaining grads
  • Write a jacobian_vector_product convenience wrapper
  • Update the hessian_vector_product wrapper to use forward mode
  • Ensure that nodes with only forward mode grads don't refer to their parents (so that garbage collection can work)
  • Implement a jacobian matrix product

Other changes:

  • Node.progenitors is now a dict, with reverse mode progenitors mapped to None and forward mode progenitors mapped to their forward mode gradient.
  • Added a few extra tests to test_systematic.py. Still no tests for forward mode.
Review comment on autograd/core.py (outdated):

    ingrads[progenitor].append(ingrad)
    assert_vspace_match(ingrad, result.vspace, self)
    result.progenitors.update({progenitor: vsum(result.vspace, *ingrads[progenitor])
                               for progenitor in ingrads})

j-towns (author) commented: ↑ This block should maybe be moved into a separate update_forward_grads function.

duvenaud commented Jan 3, 2017

Whoa, this is great!

I wonder how much mileage we can get out of adding forward-mode tests to combo_check. I suggest adding a check to see if the forward gradient is implemented, and if so, test it along a randomly-chosen input dimension without calling to_scalar on the output. This way we get most of the necessary tests for free.
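
For illustration, a minimal sketch of this kind of check; check_forward_grad and the jvp(x, u) signature are hypothetical names, not existing autograd API:

import autograd.numpy as np
import numpy.random as npr

def check_forward_grad(fun, jvp, x, eps=1e-6, tol=1e-4):
    # Build a one-hot tangent along a randomly chosen input dimension.
    u = np.zeros(np.size(x))
    u[npr.randint(np.size(x))] = 1.0
    u = np.reshape(u, np.shape(x))
    # Compare the forward-mode derivative against a central difference,
    # without reducing the output with to_scalar.
    numeric = (fun(x + eps * u) - fun(x - eps * u)) / (2 * eps)
    assert np.allclose(jvp(x, u), numeric, atol=tol)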

I suppose we really should be using some sort of code coverage tool too, since right now we can silently forget to add tests.

mattjj commented Jan 3, 2017

Yeah, this is awesome! We probably also want to compute hessian differently, like forward_mode_jacobian(jacobian(fun)) or something.
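
For reference, a runnable sketch of the composition being described; hessian_via_composition is an illustrative name, and the hypothetical forward_mode_jacobian (not yet available) would replace the outer jacobian:

import autograd.numpy as np
from autograd import jacobian

def hessian_via_composition(fun):
    # Reverse-over-reverse today; the suggestion above is
    # forward_mode_jacobian(jacobian(fun)) once forward mode lands.
    return jacobian(jacobian(fun))

f = lambda x: np.sum(np.sin(x) * x)
H = hessian_via_composition(f)(np.linspace(0.0, 1.0, 4))   # shape (4, 4)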

Also, I think this forward-mode defines a derivative operation (with respect to a scalar parameter), but we should consider defining a Jacobian-matrix product operation and thus allowing the loop over input dimensions to happen with each primitive (rather than just as an outer loop as in jacobian). That could be a huge speed win, both in avoiding autograd overhead and in leveraging more numpy to avoid Python overhead. It may be that those Jacobian-matrix products are hard to implement for all primitives, but it would be great for some primitives to support it, even if not all do. (We should think about adding a matrix-Jacobian product operation for reverse-mode too!)

j-towns commented Jan 3, 2017 via email

j-towns commented Jan 5, 2017

I'm wondering if we need a way to make a product vspace, for things like the jacobian, whose vspace should be the product of the input vspace with the output vspace...

EDIT: A tensor product to be precise...
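
To make the tensor-product point concrete (just an illustration using the existing jacobian wrapper, not a vspace implementation): the Jacobian's space has shape output_shape + input_shape.

import autograd.numpy as np
from autograd import jacobian

# The Jacobian of a map from a shape-(6,) input to a shape-(2, 3) output lives in
# the tensor product of the output space with the input space.
f = lambda x: np.reshape(np.sin(x), (2, 3))
J = jacobian(f)(np.ones(6))
assert J.shape == (2, 3, 6)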

j-towns commented Jan 5, 2017

Also, someone at NIPS was telling me it's possible to calculate Hessians 'in one go' with a single reverse pass; the method is outlined here: http://www.di.ens.fr/~rgower/pdf/HighOrderReverseAD. I imagine this would be faster for us than the multi-pass approach, but the implementation requires defining the Hessian of each primitive, which could be tedious.

mattjj commented Jan 11, 2017

Re: defining Hessian 'pullback' functions for primitives, as in my comment about implementing matrix-Jacobian products, a key advantage of autograd over other tools is its flexibility and simple implementation. I think it would be great to have a clean way to add Hessian operations to primitives, even if we only implement it for a few (as the need arises) and as long as the implementation is small and modular.

Commit: Implement more grads and forward mode tests in combo_check. Have introduced a dependency on a module called orderedset; should be able to fix this by just using a list.
j-towns commented Jan 18, 2017

I realised my implementation was pretty naive, because I paid no attention to what order forward gradients were calculated within each primitive call... I think things are now correct, but I had to use this orderedset thing for the active_progenitors. How do you guys feel about introducing this extra dependency? If necessary I can try to work out another way to do it that isn't too clunky.

I've also added forward mode tests to combo_check, which seems to cover a lot of things.

dhirschfeld commented Jan 18, 2017

@j-towns - that seems like a slightly esoteric dependency which doesn't appear to be very actively maintained or supported - e.g. no (windows) wheels, no conda packages.

If you need a sorted set you may be better off using sortedcontainers or the higher level sortedcollections which seem to have a bit of momentum behind them.

https://www.youtube.com/watch?v=7z2Ki44Vs4E

...that said, if there's a performance benefit to using orderedset it's probably not too hard to get it set up to build packages on conda-forge

j-towns commented Jan 18, 2017

One option would be to use an OrderedDict and have everything point to None.
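
For example (a small self-contained sketch; the elements are placeholders):

from collections import OrderedDict

# An insertion-ordered set: an OrderedDict whose values are all None.
ordered_set = OrderedDict()
for progenitor in ['p0', 'p1', 'p0']:
    ordered_set[progenitor] = None     # duplicates collapse, order is preserved
assert list(ordered_set) == ['p0', 'p1']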

j-towns commented Jun 8, 2017

Assuming the TF interpreter isn't factoring out the computation of g, perhaps this optimization is the kind of thing which could be written into XLA.

j-towns commented Jun 8, 2017

Ah wait I think I'm with you on the runtime FLOPs now — because all of the mappings in the reverse pass are linear as functions of v, TF can see that the nodes in the final graph needn't depend on what's computed during the reverse pass, i.e. the computation of g. Is that correct?

I would still argue that this assertion should be proved by doing some actual timings.

mattjj commented Jun 8, 2017

Yes, that's right. There's a graph to draw here, but I don't have your ASCII art skills. For chain-structured functions, we have three chains involved here: the top one is from the function eval, the middle one from the first call to tf.gradients, and the bottom one from the second call to tf.gradients. As you say, the bottom one doesn't actually depend on the values from the middle one, by linearity. Both the bottom chain and the middle chain depend on the values of the top chain.

Even better, there won't even be extra memory overhead at runtime: because the top chain and bottom chain have arrows pointing in the same direction (i.e. to the right), and the arrows from the top chain to the bottom don't skip steps, the TF graph executor is smart enough to reuse memory as it performs the evaluation, rather than evaluating things in a bad order and keeping too large an active set.

One way to convince yourself that g isn't evaluated at runtime is the fact that we don't even need to supply a value for v in feed_dict; that variable stays unbound, and we don't get an error because we never try to evaluate a node that depends on it.
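
For concreteness, a minimal sketch of the construction under discussion, assuming TF 1.x graph mode; fwd_gradients and the example function are illustrative names, not an existing API:

import tensorflow as tf

def fwd_gradients(ys, xs, d_xs):
    # Jacobian-vector product via two reverse-mode passes: g = J^T v is linear in
    # the dummy v, so differentiating g with respect to v yields J d_xs.
    v = tf.placeholder(ys.dtype, shape=ys.get_shape())   # dummy, never fed
    g = tf.gradients(ys, xs, grad_ys=v)
    return tf.gradients(g, v, grad_ys=d_xs)

x = tf.placeholder(tf.float32, shape=[3])
u = tf.placeholder(tf.float32, shape=[3])
y = tf.tanh(tf.reduce_sum(x ** 2))
jvp = fwd_gradients(y, x, u)[0]
# sess.run(jvp, feed_dict={x: ..., u: ...})  # v never appears in feed_dict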

j-towns commented Jun 8, 2017 via email

mattjj commented Jun 8, 2017

Here is a quick attempt at a graph:

[image: sketch of the computation graph described below]

Each op in the blue chain is linear in the arguments moving right to left (i.e. coming from v). Due to that linearity, the orange chain doesn't depend on the values in the blue chain. We can evaluate y and J u given values for x and u using a maximal active set size of two nodes, and indeed this is the same graph that would be produced by a manual forward-mode implementation.

All that is my current thinking, but isn't really tested, and there may be TF implementation details that we haven't accounted for.

alexbw commented Jun 8, 2017

In his paper "A Leibniz Notation for Automatic Differentiation" (Theorem 1), Bruce Christianson claims: "The three algorithms forward-over-reverse, reverse-over-forward and reverse-over-reverse are all numerically stepwise identical, in the sense that they not only produce the same numerical output values, but at every intermediate stage perform exactly the same floating point calculations on the same intermediate variable values".

If true, then we can see that the forward-mode calculation of some original function f must be buried in, and thus extractable (via dead-code elimination) from, a reverse-over-reverse application of differentiation to f. Spooky.

j-towns commented Jun 8, 2017

I see second derivatives there. Doesn't his theorem generalise to the equivalence we were discussing last week (see above) between different methods for calculating Hessian-vector products? Not to be confused with what we're talking about here, which is Jacobian-vector products.

j-towns commented Jun 8, 2017

As an aside, I agree with his point that Leibniz notation > Newton notation, although what he calls his 'bariential' notation seems unnecessary.

Edit: sorry that this thread is getting ridiculously long now...

mattjj commented Jun 8, 2017

Good eye on the 2nd derivatives there!

I like that the autograd issue tracker is a hotbed for autodiff discussion, though it would be even better if the information here were eventually distilled into some docs, or even the long-promised autograd paper.

j-towns commented Jun 12, 2017

I've written a blog post on this: https://j-towns.github.io/2017/06/12/A-new-trick.html. @mattjj and @jekbradbury, I wasn't sure how to credit you guys; I've linked to this thread at the top of the post. If you want me to include your names (or anyone else's) explicitly, let me know.

duvenaud commented:

Just chiming in to say how awesome and satisfying this all is. It's like something out of SICP.

It seems obvious in retrospect: How do you get forward-mode? By reversing reverse mode! :)

mattjj commented Jun 12, 2017

Awesome blog post, @j-towns! I think the way you wrote it is great; citing the thread is a good way to share credit. It might have broader audience appeal if you include the TF implementation in the post, too, since TF is so popular right now (maybe linking to my comment in the thread).

@duvenaud I agree, and I think we were right on top of this idea for a while without noticing it.

j-towns commented Jun 12, 2017

Good point. I've put a link at the very bottom of the blog post and also on the post I've made on reddit.

Glad you like it by the way! 🤓

j-towns commented Jun 13, 2017

Hmm, I'm gonna have a go at running my forward mode implementation, but applying the new vjp(vjp) method on a primitive-by-primitive basis to get each primitive's jvp (rather than using the jvps that I've hand-implemented).

It will be interesting to see how much overhead this introduces over the hand-implemented jvps (this was discussed a bit above – I reckon there will probably be some overhead).

If the overhead is negligible this is great news because it means we only have to maintain the vjps and can get rid of all the jvps.

Edit: I guess at the very least this could be a handy fallback for primitives where I haven't written a jvp.
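
A sketch of that per-primitive fallback, assuming a make_vjp with signature make_vjp(f)(x) -> (vjp, f(x)); the function name and usage here are illustrative, not the final interface:

import autograd.numpy as np
from autograd import make_vjp

def jvp_via_double_vjp(fun, x, u):
    # Because vjp is linear in its argument, a second reverse pass through it
    # recovers the forward-mode product J(x) u.
    vjp, y = make_vjp(fun)(x)                             # vjp: v -> J(x)^T v
    vjp_of_vjp, _ = make_vjp(vjp)(np.zeros(np.shape(y)))  # differentiate wrt v
    return y, vjp_of_vjp(u)                               # (J^T)^T u = J(x) u

f = lambda x: np.sin(np.dot(x, x))
y, ju = jvp_via_double_vjp(f, np.array([1.0, 2.0]), np.array([1.0, 0.0]))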

j-towns reopened this Jun 13, 2017
j-towns closed this Jun 13, 2017
dougalm commented Jun 18, 2017

Finally catching up on this thread. Beautiful work, @j-towns and @mattjj! @duvenaud is right, it's one of those elegant results that's natural and obvious once it's pointed out but takes some real wizardry to get to in the first place.

I've been trying to think of ways we can avoid unnecessary FLOPS (FLOs?) in make_jvp and make_ggnvp. Here's one approach: we can introduce a special 'zero vector' object which the primitive wrapper can recognize and propagate FLO-free. But it's only ok to propagate a zero if the primitive is linear in that argument. We know the vjps are all linear, but we need to trace their component functions, which can be arbitrary primitives (although they're almost certainly linear in practice). One trick for figuring out linearity is to use the fact that a primitive is linear in one of its args iff its vjp doesn't use that arg in its definition (aside from things like extracting shape information, but our vspaces should make that unnecessary).
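
A toy version of the zero-propagation idea (ZeroVector and forward_apply are illustrative names, not autograd internals):

class ZeroVector(object):
    """Sentinel for an all-zeros tangent that linear primitives can pass through for free."""

ZERO = ZeroVector()

def forward_apply(prim, is_linear_in_tangent, tangent, *args):
    # If the incoming tangent is ZERO and prim is linear in that argument,
    # the output tangent is ZERO too and no FLOs are spent.
    if tangent is ZERO and is_linear_in_tangent:
        return ZERO
    return prim(tangent, *args)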

Off the top of my head, here are three options for figuring out linearity:

  1. Inspect a function's body to detect whether it uses each arg. CPython must do some version of this internally because it garbage-collects variables present in the environment of a closure when the inner function doesn't use those variables. Maybe this machinery is exposed somewhere (see the sketch after this list).
  2. Use special variable names in vjp definitions to denote 'unused variable', which we can easily extract with the inspect module. I wish we could use _, but it's a regular identifier and can't be used twice in a function's args.
  3. Explicitly tag the primitive with the linearity information. We will probably always need this option as a fallback for functions that can take an unbounded number of arguments, like concatenate_args.
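
A quick check of option 1 (the closure here is just an illustrative vjp for dot's first argument):

import numpy as np

# CPython closures only capture the free variables they actually reference,
# and expose their names on __code__.co_freevars.
def make_vjp_dot_arg0(a, b, ans):
    def vjp(g):
        return g * b          # uses b but not a: dot(a, b) is linear in a
    return vjp

vjp = make_vjp_dot_arg0(np.ones(3), np.ones(3), 3.0)
print(vjp.__code__.co_freevars)   # ('b',)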

A bonus side effect of doing this is that we can use the linearity information to avoid storing objects that the vjp doesn't need. As this thread makes clear, people really do care about memory use. Currently, we're needlessly hogging memory with linear operations like matrix multiplies.

But here's one major wrinkle: in order to propagate zeros, we need to be able to predict the output vspace of a function without evaluating it. We've basically been building our own type system within Autograd, so function signatures would be a natural thing to add. It's a bit tricky because the output vspace isn't just a function of the input vspace(s) but also of (non-float) args like 'keepdims'. It's also a bit of a pain to have to write out all these function signatures. On the other hand, it would be hugely helpful for automated testing of primitives.

j-towns commented Jun 19, 2017

Yeah, that sounds like a good template for how we could bring the FLO cost down for jvp/ggnvp. I have a minor question on the point about avoiding storing objects that aren't needed: would it actually be possible to know ahead of time that a variable isn't going to be needed by any vjp?

It seems to me that if the user does something like this:

def f(x):
    a = np.ones(3)
    y = np.dot(a, x)
    return y**x

Although the third line is linear in x, Autograd has no way of knowing (until the function has returned/the graph has been completed) that x won't be needed somewhere else during the reverse pass.

Pruning nodes surely wouldn't be able to happen until the whole graph had been constructed, by which point we don't really care, since it's the peak memory usage that we wish to reduce.

EDIT: Thinking a bit more, I suppose there would be cases where you could avoid storing nodes, if a variable is created inside some inner scope which is then exited. To get this to work, linear primitives could be in some way tagged as linear, and the nodes which they output can simply not include the parent in their 'recipe'.
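
A rough sketch of that tagging idea (deflinear, linear_argnums and parents_to_store are hypothetical names, not existing autograd API):

# Map each primitive to the argument positions it is linear in. Nodes produced by
# such a primitive need not keep a reference to those parents in their 'recipe',
# so Python can garbage-collect the values once the defining scope is exited.
linear_argnums = {}

def deflinear(primitive, argnum=0):
    linear_argnums.setdefault(primitive, set()).add(argnum)

def parents_to_store(primitive, parents):
    linear = linear_argnums.get(primitive, set())
    return [p for argnum, p in enumerate(parents) if argnum not in linear]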

dougalm commented Jun 19, 2017

@j-towns, you're raising excellent points. I'm living in the past. There was a brief period of Autograd's history when we separated the value-storing Node from the graph-representing ReverseNode. In fact, this sort of memory optimization was the whole motivation for the ReverseNode. In those days, the vjp functions returned a closure and Python was clever enough not to store those variables that weren't used in the inner function. The only data structure you'd end up with after evaluating your composite function was a tape of ReverseNodes, with the values themselves stored implicitly in the closures as needed. The sad thing is that we ended up adding the value to ReverseNode to track (generalized) shape information, which completely defeated the purpose. Now that we have vspaces, we could just store those instead.

So it looks like what I'm proposing would require re-introducing some sort of ReverseNode. The memory savings are pretty compelling, so it might be worth it. And we have benchmarks now to see whether the overhead is really anything to worry about. We could even go back to the nested functions of defgrad to get automatic linearity detection, but I feel bad constantly swapping out the interface from under people's feet.

j-towns commented Jun 19, 2017

Ah, OK, that all makes sense. I'm just wondering if there's some way we can get the benefits of Python's garbage collection without having two separate Node types (which is nasty IMO).

What about keeping Node as it is, but instead of having it contain vjp and recipe attributes, have it contain just a vjp, which is a function that evaluates the vjp from the Node all the way back to its progenitor? This would also mean that backward_pass would simply evaluate the final Node's vjp function, removing the need for the toposort and for manually looping over the graph. Please shoot me down if there are obvious reasons why that's bad!

Edit: I guess it wouldn't be possible to efficiently handle fan out and fan in with what I think I was suggesting there.
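
A toy version of that composed-vjp idea (purely illustrative; as the edit above notes, a simple chain like this can't handle fan-out/fan-in efficiently):

class Node(object):
    # Each node carries a single vjp that composes its local vjp with its
    # parent's, all the way back to the progenitor.
    def __init__(self, value, local_vjp=None, parent=None):
        self.value = value
        if parent is None:
            self.vjp = lambda g: g                          # progenitor: identity
        else:
            self.vjp = lambda g: parent.vjp(local_vjp(g))   # compose backwards

def backward_pass(g, end_node):
    # No toposort or explicit loop over the graph is needed for a simple chain.
    return end_node.vjp(g)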

j-towns commented Jun 19, 2017

A minor point: I think it's kind of elegant having the local primitive vjps constructed by make_vjp closures, since this matches the global make_vjp function.

dougalm commented Jun 20, 2017

Edit: I guess it wouldn't be possible to efficiently handle fan out and fan in with what I think I was suggesting there.

Yup, you got it.

A minor point: I think it's kind of elegant having the local primitive vjps constructed by make_vjp closures, since this matches the global make_vjp function.

You know, I have to agree. That plus the memory optimization and I'm starting to regret changing the interface. The arguments at the time were (1) slight reduction in overhead and (2) simpler interface for users. If it weren't for the albatross of backwards compatibility I'd change it back right now (keeping the vspaces).

mattjj commented Jun 20, 2017

I say make the change! We can figure out a way to make the transition easy enough. Maybe we could require uses of the non-staged defvjp method to be rewritten as defvjp_unstaged to keep working, which isn't too much of a burden. Or we could add a defvjp_staged method for the new-old way.

def defvjp(self, vjpfun, argnum=0):
    warnings.warn(defvjp_staging_is_better)
    def make_vjp(ans, vs, gvs, *args, **kwargs):
        def vjp(g): return vjpfun(g, ans, vs, gvs, *args, **kwargs)
        return vjp
    self.defvjp_staged(make_vjp, argnum)

(Not sure where the ans, vs, and gvs arguments should go, but you get the idea.)

Is there a problem I'm overlooking with switching between staged and unstaged interfaces like this?

IMO autograd's internals should stay agile.

dougalm commented Jun 20, 2017

That's the spirit!

Technically, defvjp is external because we mention it in the tutorial. But then again, so is our digestive tract.

Ok, I'll make a branch with the changes. It'll be nice to finally benchmark it too.

j-towns commented Jun 20, 2017

Do we think things should be named as they were before, with a Node and ReverseNode? I'm wondering if ReverseNode should really be called Node, since the reverse nodes form a graph, and then what used to be called Node could be renamed something like ActiveVariable?

j-towns mentioned this pull request Jul 19, 2017
j-towns referenced this pull request Aug 21, 2017
…todiff from tracing, forward mode falls out quite naturally as an alternative trace type. We still need to implement the primitive jacobian-vector products. Luckily j-towns has done a lot of this already.