
Separating frule and pushforward prevents efficient solutions (fuse pushforward) #74

Closed
ChrisRackauckas opened this issue Dec 28, 2019 · 25 comments · Fixed by #88

@ChrisRackauckas
Member

When implementing the frules, I realized that the current design of frule doesn't allow for standard optimizations that are seen in forward-rule implementations. The reason is that forward mode works by propagating the derivative along step by step. A good primer on all of this is this set of notes:

https://mitmath.github.io/18337/lecture9/autodiff_dimensions

Essentially, what we are trying to do with ChainRules.jl is allow the user to describe how to calculate f(x) and f'(x)v, the primal and the jvp (Jacobian-vector product). Currently the formulation is:

function frule(::typeof(foo), args; kwargs...)
    ...
    return y, pushforward
end
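For concreteness, here is what a hypothetical rule for sin looks like in this closure style (a sketch only, not the actual ChainRules definition):

function frule(::typeof(sin), x; kwargs...)
    y = sin(x)
    pushforward(dx) = cos(x) * dx  # the jvp: f'(x) * dx
    return y, pushforward
end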

In both cases, pushforward is called as pushforward(dargs). However, given that discussion of forward-mode differentiation, one can see that this runs contrary to how the jvp is actually calculated. Here are two examples:

Example 1: Implementing ForwardDiff over frules

As described in the notes, the dual-number way of computing forward mode starts by seeding dual numbers. In standard ForwardDiff usage, these seeds are all unit basis vectors, as shown in the DiffEq documentation for how to AD through the solver manually:

https://docs.juliadiffeq.org/v6.8/analysis/sensitivity/#Examples-using-ForwardDiff.jl-1

But as mentioned in the notes, what this is really doing is seeding the duals in the basis-vector directions e_i, so the jvp computes J*e_1, J*e_2, J*e_3 as separate vectors, giving a representation of the full Jacobian. If you do get the whole Jacobian, then of course you can do J*v, and this is what the current frule would allow:

function frule(::typeof(f), x; kwargs...)
    dual_x = seed_duals(x)  # seeds along the basis vector directions,
                            # giving duals with length(x) partial dimensions
    dual_y = f(dual_x)
    y, dy = value(dual_y), partials_as_matrix(dual_y)  # y is the primal, dy is the Jacobian
    function pushforward(dx)
        dy * dx
    end
    return y, pushforward
end

However, there is a more efficient way to calculate y and dy*dx: if we know dx at the start, we can just seed the dual numbers along the direction of dx, which reduces the number of partial dimensions from length(x) to 1:

function frule(::typeof(f), x, dx; kwargs...)
    dual_y = f(dual.(x, dx))  # duals with a single partial dimension
    y, dy = value(dual_y), partials_as_matrix(dual_y)  # y is the primal, dy is f'(x)*dx
    return y, dy
end

This changes it from an O(n) computation to O(1)!
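To make the single-direction seeding concrete, here is a runnable sketch using ForwardDiff's Dual directly; the function f is a toy stand-in, and explicit broadcasts replace the seed_duals/partials_as_matrix pseudocode above:

using ForwardDiff: Dual, value, partials

f(x) = [x[1]^2 + x[2], 3x[2]]  # toy function; any smooth map works here

x  = [1.0, 2.0]
dx = [1.0, 0.5]  # the direction to push forward

dual_y = f(Dual.(x, dx))  # one partial per number, seeded along dx

y  = value.(dual_y)                    # primal: f(x)
dy = map(d -> partials(d)[1], dual_y)  # jvp: f'(x)*dx, no Jacobian materialized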

Example 2: Implementation of Forward Sensitivity Analysis for ODEs

Now here's a more concrete example from the user side. For ODEs, you want to look at:

u' = f(u,p,t)

and you want to know how u(t) changes w.r.t. p. So take d/dp of both sides of the ODE, and by the chain rule (swapping the order of differentiation, assuming nice smoothness properties) you get

d/dt du/dp = df/du du/dp + df/dp

calling S = du/dp, this is just

S' = (df/du)*S + df/dp

So you get another ODE that gives you the derivatives of the solution of the original ODE w.r.t. the parameters. This is the continuous pushforward rule! Now the difficulty is that you need to be able to calculate (df/du)(t), which requires knowing u(t). In theory you could compute a continuous solution u(t) beforehand by solving the original ODE and storing it, but that's not a good way to do it. Instead, just realize that if you solve the ODEs:

u' = f(u,p,t)
S' = (df/du)*S + df/dp

together, then you always know u, since it's the first part of the system! So magic happens and this is very efficient.
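As a runnable sketch of this combined system, take the toy scalar ODE u' = p*u, so that df/du = p and df/dp = u (OrdinaryDiffEq syntax; everything here is illustrative, not the actual DiffEq sensitivity machinery):

using OrdinaryDiffEq

function combined!(dx, x, p, t)
    u, S = x[1], x[2]  # primal state and sensitivity S = du/dp
    dx[1] = p * u      # u' = f(u, p, t)
    dx[2] = p * S + u  # S' = (df/du)*S + df/dp
end

x0 = [1.0, 0.0]  # seed S(0) = 0 to get du/dp
prob = ODEProblem(combined!, x0, (0.0, 1.0), 0.5)
sol = solve(prob, Tsit5())

u_T, dudp_T = sol.u[end]  # primal and sensitivity at t = 1
# analytic check: u(t) = exp(p*t) and du/dp = t*exp(p*t)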

That's almost there. But what sensitivities are we pushing forward? You can seed the sensitivities with S(0) = 0, and the output is S = du/dp, but that's not satisfying. What if you wanted to know du/d(u0) as well as du/dp? Since concrete_solve(p,u0,odeprob,solver,...) is a function of both p and u0, we want the derivative of the ODE's solution with respect to both p and u0.

It turns out from simple math that all you have to do is seed S(0) = du0! So, in "composed frule" notation, you'd do the following:

function frule(::typeof(concrete_solve), p, dp, u0, du0, odeprob)
  S = du0
  _prob = build_bigger_ode(odeprob, [u0, S])
  sol = solve(_prob, solver)
  y, dy = split_solution(sol)
  return y, dy .* dp  # weight by the direction vector!
end

Right now, this can't really be expressed.

API

Actually, having those arguments might be difficult, so maybe it's easier to write it as:

function frule(::typeof(f), x, dx; kwargs...)
    function pushforward(dx)  # inner dx shadows the outer one, which is there for dispatch
        dual_y = f(dual.(x, dx))  # duals with a single partial dimension
        y, dy = value(dual_y), partials_as_matrix(dual_y)  # y is the primal, dy is f'(x)*dx
        return y, dy
    end
    return pushforward
end

Anyway, the exact API is an interesting question, but whatever it is, the computation should have x and dx available at the same time.

@oxinabox
Member

If we exposed the seed as an argument to the outer function,
it would also solve #53 in the case of frule,
since the thing one would want to overload to handle different types of differential would just be the frule itself, not a closure hidden inside it.

And avoiding creating the closure would likely make the compiler work better.

@oxinabox
Member

oxinabox commented Dec 28, 2019

To break down the ODE example a bit more, and why we can't solve it with the current frule + pushforward:

If we wanted to do

function frule(::typeof(concrete_solve), p, dp, u0, du0, odeprob)
  S = du0
  _prob = build_bigger_ode(odeprob, [u0, S])
  sol = solve(_prob, solver)
  y, dy = split_solution(sol)
  return y, dy .* dp  # weight by the direction vector!
end

right now we have two options:

  • Option one: outside the pushforward, use solve on the normal system to get y, then build_bigger_ode + solve + split_solution inside the pushforward to get dy. But this is wasteful, because now we are redoing work in the pushforward to compute y, and we can't remove that part from the bigger _prob because it's not independent of solving for the derivatives.

  • Option two: outside the pushforward, call build_bigger_ode + solve + split_solution and remember the result for dy, which we then use in the pushforward. But that's no good, as build_bigger_ode needs du0, i.e. the sensitivities of its inputs, in order to work, and we don't have those until the pushforward is called. We could replace du0 with the basis vectors, run many times, and compute the Jacobian to use in the pushforward, but that's way more work. And in the pushforward we would then have to multiply the Jacobian by du0, which is itself a potentially nontrivial matmul.

Both are terrible.

YingboMa added a commit to YingboMa/ForwardDiff2.jl that referenced this issue Dec 29, 2019
@oxinabox oxinabox changed the title Optimization of frules and forward AD implementation separating frule and pushforward prevents efficient solutions Dec 29, 2019
@oxinabox oxinabox changed the title separating frule and pushforward prevents efficient solutions Separating frule and pushforward prevents efficient solutions Dec 29, 2019
@oxinabox
Member

oxinabox commented Dec 29, 2019

What should the API be?
I kind of like the idea of:

linrule(f, (args...), (dself, dargs...); kwargs...) = ((fvals...), (dfvals...))
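For instance, a hypothetical linrule for sin in that shape (a sketch only; nothing stands in for dself here, which is ignored for plain functions):

linrule(::typeof(sin), (x,), (dself, dx); kwargs...) = ((sin(x),), (cos(x) * dx,))

(y,), (dy,) = linrule(sin, (1.0,), (nothing, 0.5))  # dy == cos(1.0) * 0.5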

@YingboMa
Member

The API that I implemented in ForwardDiff2.jl is

julia> function ChainRulesCore.frule(::typeof(cumsum), x::AbstractVector, dx)
           function f_pushforward(x, dx)
               return cumsum(x), cumsum(dx, dims=1)
           end
       end

julia> a, b = rand(5), rand(5); f_pushforward = frule(cumsum, a, b); f_pushforward(a, b)
([0.0025809013385238444, 0.7648493671053664, 1.1185434888949843, 1.8270179446376797, 2.350927101459723], [0.24271532595247858, 0.8825882669221625, 1.3966390333846403, 2.07871296820404, 3.015967442534456])

julia> a, b = rand(5), rand(5, 5); f_pushforward(a, b)
([0.9997930779816582, 1.5638602684445428, 2.0694747840930843, 2.895235231851567, 3.693583667087273], [0.5907283915879293 0.1977100837168979 … 0.23178849580431904 0.7769216080950119; 1.1754562291522375 0.7468940097801302 … 0.47644964933541734 1.1340193375193208; … ; 2.3833577218064717 2.4033854494631264 … 1.0361063668932011 1.8776613309790917; 2.5465035717134303 3.1971200202671013 … 1.467669710763345 2.6282230037188343])

It makes the implementation easier because there are no enclosed values, and I don't need to overdub the frule function either. See https://github.com/YingboMa/ForwardDiff2.jl/blob/4d0106e31427c37903c2e9f725c3f59a17d75de0/src/dual_context.jl#L82-L97

@oxinabox
Member

I am really not a fan of returning a function with the exact same inputs as the original.
It feels so redundant.

@YingboMa
Member

Maybe just

julia> function ChainRulesCore.frule(::typeof(cumsum), x::AbstractVector, dx)
               return cumsum(x), cumsum(dx, dims=1)
       end
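With that flattened form, the call site collapses to a single call; a usage sketch, assuming the rule above has been defined:

using ChainRulesCore: frule

x, dx = rand(5), rand(5)
y, dy = frule(cumsum, x, dx)  # primal and directional derivative in one call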

@oxinabox
Member

Should we wrap the dx in a tuple? And the x? Especially for the case of multivariate functions, to make it clear where each starts and ends... But that would make writing type signatures harder, hmmm

@willtebbutt
Member

willtebbutt commented Dec 31, 2019

I think it's pretty clear from these discussions that the way forwards in terms of what we return is just the output - no closures over anything.

In terms of the inputs, you could use a dual-numbers-inspired approach whereby each input becomes a tuple containing its primal and differential. So you could have something like:

function frule(::typeof(foo), (x, dx), (y, dy), ...; kwargs...)
    return 
end

This is what we currently do in FiniteDifferences to implement jvp. I'm a little wary of proposals that group together all of the xs and dxs for functions of multiple inputs, just because you would have to be careful about how things are positioned.
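As a concrete sketch of that tuple-pairing style for a unary function (hypothetical; the signature that eventually shipped differs):

function frule(::typeof(sin), (x, dx)::Tuple{Real, Any}; kwargs...)
    return sin(x), cos(x) * dx
end

y, dy = frule(sin, (1.0, 0.5))  # dy == cos(1.0) * 0.5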

It's worth reiterating that, as @oxinabox pointed out, this means that our expression problem #53 isn't an issue for forward rules.

@oxinabox
Member

oxinabox commented Jan 2, 2020

I do like @willtebbutt's solution.
It would also make splatting easier to deal with for splatted arguments.

And I suspect it will be easier for ForwardDiff2 purposes?

It does, however, mean we would have to write type signatures with Tuples everywhere.

function frule(::typeof(foo), (x, dx)::Tuple{Real, Any}, (y, dy)::Tuple{Vector{<:Real}, Any}, ...; kwargs...)
    return 
end

But macros could fix that up. (#44)

@YingboMa do you have a preference between yours and Will's proposal?
If you're willing to make the changes then I am happy to go with whichever you think.

@oxinabox
Member

oxinabox commented Jan 3, 2020

actually the type-signature would be:

function frule((_,dself)::Tuple{typeof(foo), Any}, (x, dx)::Tuple{Real, Any}, (y, dy)::Tuple{Vector{<:Real}, Any}, ...; kwargs...)
    return 
end

I do kinda like how it pairs the dself with the function; it makes it much clearer what it belongs to.

@YingboMa
Member

YingboMa commented Jan 3, 2020

I think you might have meant (_,dself)::Tuple{Any, typeof(foo)}?

@oxinabox
Member

oxinabox commented Jan 3, 2020

I think you might have meant (_,dself)::Tuple{Any, typeof(foo)}?

No? dself has the type of the differential for the sensitivity of the fields of foo, if it's a functor.
For some functors typeof(foo) might itself be a valid differential type, but as a rule it is not.
It is more likely to be a Composite{typeof(foo)}.

Or, in the most common case, when foo is a function (or a functor with no fields),
it can be anything, since it is going to be ignored.
But the tests mostly pass it as NamedTuple() (though not really for a good reason anymore; better is probably Zero()?)
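To illustrate, here is a sketch with a functor whose field carries a differential; a NamedTuple plays the role that a Composite{typeof(foo)} would, and all names are illustrative:

struct Scale
    a::Float64
end
(s::Scale)(x) = s.a * x

# product rule for s(x) = a*x: d(a*x) = da*x + a*dx
function frule((s, dself)::Tuple{Scale, Any}, (x, dx)::Tuple{Real, Any})
    return s(x), dself.a * x + s.a * dx
end

s = Scale(2.0)
y, dy = frule((s, (a = 1.0,)), (3.0, 0.5))  # dy == 1.0*3.0 + 2.0*0.5 == 4.0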

@oxinabox
Member

oxinabox commented Jan 8, 2020

@YingboMa do you plan on doing this change?
Should we hold off the 0.5 release to wait for it?
If it's not going to be soon, it's probably better not to hold 0.5 for it;
then we can get ChainRules updated to the other changes.

@YingboMa
Member

YingboMa commented Jan 8, 2020

Don't hold 0.5 for this. I don't have much time right now.

@YingboMa
Member

I will do it this weekend.

@willtebbutt
Member

We never actually agreed on the desired resolution to this issue. Implemented in #88 is the solution proposed above by @YingboMa, which was merged rather hastily over the weekend.

@YingboMa @ChrisRackauckas : are either of you especially tied to the implemented solution? FWIW I much prefer the one I proposed -- seems less error prone and has worked well thus far in FiniteDifferences.

@willtebbutt willtebbutt reopened this Jan 13, 2020
@YingboMa YingboMa removed their assignment Jan 14, 2020
@YingboMa
Member

I don't think that is a good API for AD implementers. One would have to zip the primals and partials.

@oxinabox
Member

Why would one need to zip them?

For

xd::Dual{Float64}
yd::DualArray{Float64, 1}
foo(x::Real, y::Vector)

It's:

frule((foo, Zero()), (value(xd), partials(xd)), (data(yd), partials(yd)))

vs:

frule(foo, value(xd), data(yd), Zero(), partials(xd), partials(yd))

I don't see the zip in either?

@YingboMa
Member

If you have overdub(..., f, arg...), then you have to zip map(value, arg) and map(partials, arg).
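To make the zipping concrete, here is a self-contained sketch with a stand-in dual type (illustrative only, not ForwardDiff2's actual internals):

struct MyDual  # stand-in for a dual number
    value::Float64
    partial::Float64
end

args = (MyDual(1.0, 0.1), MyDual(2.0, 0.2))  # the duals seen inside overdub

vals  = map(d -> d.value, args)
parts = map(d -> d.partial, args)

# tuple-pairing style: must zip primals back together with their partials
paired_args = Tuple(zip(vals, parts))  # ((1.0, 0.1), (2.0, 0.2))

# flat style: just splat the two maps, no zipping needed
flat_args = (vals..., parts...)  # (1.0, 2.0, 0.1, 0.2)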

@YingboMa
Member

Also, I am not super thrilled about dispatching on Tuple{<:T, Any}.

@oxinabox
Member

If you have overdub(..., f, arg...), then you have to zip map(value, arg) and map(partials, arg).

Right.

I think that's a good argument.

@willtebbutt
Member

Okay, thanks for the explanation @YingboMa . Seems reasonable to me :)

@oxinabox
Member

Docs need to be updated in the ChainRules package.

@YingboMa YingboMa reopened this Jan 16, 2020
@shashi
Collaborator

shashi commented Jan 16, 2020

It turns out we need to change #88 to be of the form:

res = frule(f, x..., partials...)
if res !== nothing
    fx, pushforward = res
    partials = pushforward(Zero(), partials...)
end

We're starting to think about Taylor-mode FD, where we need to differentiate through the pushforward. If we don't have the pushforward as a separate function, then we'd have to differentiate a call to frule, which also re-runs the primal computation.

@willtebbutt
Member

Let's keep this closed and move discussion over to #67

@willtebbutt willtebbutt mentioned this issue Feb 16, 2020
@oxinabox oxinabox changed the title Separating frule and pushforward prevents efficient solutions Separating frule and pushforward prevents efficient solutions (fuse pushforward) Jan 30, 2021