Separating frule and pushforward prevents efficient solutions (fuse pushforward) #74
Comments
If we exposed the seed as an argument to the outer function and avoided creating the closure, the compiler would likely work better. |
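For illustration, the contrast being discussed might look like this (a hypothetical sketch; `frule_closure` and `frule_seeded` are illustrative names, not the real ChainRulesCore API):

```julia
# Closure-based form: the seed dx arrives only after the rule returns,
# so the rule cannot exploit knowledge of dx.
function frule_closure(::typeof(sin), x)
    return sin(x), dx -> cos(x) * dx   # closes over x
end

# Seed-as-argument form: dx is available up front, no closure needed.
frule_seeded(::typeof(sin), x, dx) = (sin(x), cos(x) * dx)

y1, pf = frule_closure(sin, 1.0)
dy1 = pf(0.5)
y2, dy2 = frule_seeded(sin, 1.0, 0.5)
```

The seeded form gives the rule author the tangent at the same time as the primal input, which is what makes fusion possible.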
To break down the ODE example a bit more, and why we can't solve it with the current API: if we wanted to do this,
right now we have two options:
Both are terrible. |
What should the API be?
|
The API that I implemented in ForwardDiff2.jl is

julia> function ChainRulesCore.frule(::typeof(cumsum), x::AbstractVector, dx)
function f_pushforward(x, dx)
return cumsum(x), cumsum(dx, dims=1)
end
end
julia> a, b = rand(5), rand(5); f_pushforward = frule(cumsum, a, b); f_pushforward(a, b)
([0.0025809013385238444, 0.7648493671053664, 1.1185434888949843, 1.8270179446376797, 2.350927101459723], [0.24271532595247858, 0.8825882669221625, 1.3966390333846403, 2.07871296820404, 3.015967442534456])
julia> a, b = rand(5), rand(5, 5); f_pushforward(a, b)
([0.9997930779816582, 1.5638602684445428, 2.0694747840930843, 2.895235231851567, 3.693583667087273], [0.5907283915879293 0.1977100837168979 … 0.23178849580431904 0.7769216080950119; 1.1754562291522375 0.7468940097801302 … 0.47644964933541734 1.1340193375193208; … ; 2.3833577218064717 2.4033854494631264 … 1.0361063668932011 1.8776613309790917; 2.5465035717134303 3.1971200202671013 … 1.467669710763345 2.6282230037188343])

It makes the implementation easier because there are no enclosed values, and I don't need to |
I am really not a fan of returning a function with the exact same inputs as the original. |
Maybe just

julia> function ChainRulesCore.frule(::typeof(cumsum), x::AbstractVector, dx)
return cumsum(x), cumsum(dx, dims=1)
end |
Should we wrap the |
I think it's pretty clear from these discussions that the way forward in terms of what we return is just the output - no closures over anything. In terms of the inputs, you could use a dual-numbers inspired approach whereby each input becomes a tuple containing its primal and differential. So you could have something like

function frule(::typeof(foo), (x, dx), (y, dy), ...; kwargs...)
    return
end

This is what we currently do. It's worth reiterating that, as @oxinabox pointed out, this means that our expression problem #53 isn't an issue for forward rules. |
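A minimal sketch of this tuple-style rule for `cumsum` (the `frule` here is a standalone demo function in this proposed shape, not the actual ChainRulesCore definition):

```julia
# Tuple-style rule: each input is a (primal, differential) pair, and the
# rule returns (primal_output, differential_output) directly - no closure.
function frule(::typeof(cumsum), (x, dx))
    return cumsum(x), cumsum(dx, dims=1)
end

x, dx = rand(5), rand(5)
y, dy = frule(cumsum, (x, dx))
```

One design consequence: because the primal and tangent travel together, the rule sees both at once and can fuse their computation.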
I do like @willtebbutt's solution. And I suspect it will be easier for ForwardDiff2 purposes? It does however mean we would have to write type signatures with Tuples everywhere.
But macros could fix that up. (#44) @YingboMa, do you have a preference between yours and Will's proposal? |
actually the type-signature would be:
I do kinda like how it pairs the |
I think you might have meant |
No? Or in the most common case when |
@YingboMa do you plan on doing this change? |
Don't hold 0.5 for this. I don't have much time recently. |
I will do it this weekend. |
We never actually agreed on the desired resolution to this issue. Implemented in #88 is the solution proposed above by @YingboMa , which was merged rather hastily over the weekend. @YingboMa @ChrisRackauckas : are either of you especially tied to the implemented solution? FWIW I much prefer the one I proposed -- seems less error prone and has worked well thus far in FiniteDifferences. |
I don't think that is a good API for AD implementers. One would have to zip the primals and partials. |
Why would one need to zip? For
It's:
vs:
I don't see the zip in either? |
If you have |
Also, I am not super thrilled about dispatching on |
Right. I think that's a good argument. |
Okay, thanks for the explanation @YingboMa . Seems reasonable to me :) |
Docs need to be updated in the ChainRules package
It turns out we need to change #88 to be of the form: res = frule(f, x..., partials...)
if res !== nothing
fx, pushforward = res
partials = pushforward(Zero(), partials...)
end

We're starting to think about Taylor mode FD where we need to differentiate through the pushforward. If we don't have the pushforward as a separate function, then we'd have to differentiate a call to |
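For illustration, a rule in this returns-a-pushforward shape might look like the following sketch (`frule_demo` and the `nothing` placeholder standing in for `Zero()` are illustrative, not the real ChainRulesCore definitions):

```julia
# frule returns the primal result together with a pushforward function,
# so Taylor-mode AD can differentiate the pushforward as its own callable.
function frule_demo(::typeof(cumsum), x::AbstractVector)
    y = cumsum(x)
    pushforward(Δself, dx) = cumsum(dx, dims=1)  # separately differentiable
    return y, pushforward
end

x, dx = rand(5), rand(5)
res = frule_demo(cumsum, x)
if res !== nothing
    fx, pushforward = res
    dy = pushforward(nothing, dx)  # nothing plays the role of Zero()
end
```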
Lets keep this closed and move discussion over to #67 |
When implementing the frules, I realized that the implementation of `frule` doesn't allow for standard optimizations that are seen in forward rule implementations. The reason is that forward mode works by propagating the derivative along step by step. A good primer on all of this is this set of notes: https://mitmath.github.io/18337/lecture9/autodiff_dimensions
Essentially, what we are trying to do with ChainRules.jl is allow the user to describe how to calculate `f(x)` and `f'(x)v`, the primal and the jvp. Currently the formulation is one where `pushforward` is `pushforward(dargs)`. However, given that discussion of forward mode differentiation, one can see that this runs contrary to how it is actually calculated. Here are two examples of it.

Example 1: Implementing ForwardDiff over frules
As described in the notes, the dual number way of computing forward mode starts by seeding dual numbers. In standard ForwardDiff usage, these seeds are all unique basis vectors, as shown in the DiffEq documentation for how to AD through the solver manually:

https://docs.juliadiffeq.org/v6.8/analysis/sensitivity/#Examples-using-ForwardDiff.jl-1

But as mentioned in the notes, what this is really doing is seeding the duals in the basis vector `e_i` directions, so the jvp is computing `J*e_1`, `J*e_2`, `J*e_3` as separate vectors, giving a representation of the full Jacobian. If you do get the whole Jacobian, then you can do `J*v` of course, and this is what the current `frule` would allow. However, there is a more efficient way to calculate `y` and the jvp `J*dx`: since we know the `dx` at the start, we can just seed the dual numbers along the direction of `dx`, which changes the number of dual dimensions from `length(x)` to 1. This changes it from an O(n) computation to O(1)!
Example 2: Implementation of Forward Sensitivity Analysis for ODEs
Now here's a bit more concrete example for the user side. For ODEs, you want to look at:

u' = f(u, p, t)

and you want to know how u(t) changes w.r.t. p. So take the d/dp of both sides of the ODE and by the chain rule (swap the order of differentiation, assume nice properties) you get:

d/dt (du/dp) = (df/du) (du/dp) + df/dp

Calling `S = du/dp`, this is just:

S' = (df/du) S + df/dp

So you get another ODE that gives you the derivatives of the solution of the original ODE w.r.t. parameters. This is the continuous pushforward rule! Now the difficulty is that you need to be able to calculate `(df/du)(t)`, which requires that you know `u(t)`. Now in theory you could calculate `u(t)` as a continuous solution beforehand by solving the previous ODE and storing it, but that's not the good way to do it. The way you do it is to realize that, if you solve the ODE system:

u' = f(u, p, t)
S' = (df/du) S + df/dp

together, then you always know `u`
since it's the first part of the equation! So magic happens and this is very efficient.

That's almost there. What sensitivities are we pushing forward though? You can seed the sensitivities from `S = 0` and the output is `S = du/dp`, but that's not satisfying. What if you wanted to know `du/d(u0)` and `du/dp`? Since `concrete_solve(p, u0, odeprob, solver, ...)` is a function of both `p` and `u0`, we want the derivative of the ODE's solution with respect to the `p` and the `u0`. It turns out from simple math that all you have to do is set `S = du0`! So then, in "composed frule" notation, you'd do the following:

Right now, this can't really be expressed.
API
Actually having those arguments might be difficult, so maybe it's easier to write as:

Anyways, the exact API is an interesting question, but whatever it is, the computation should have the `x` and the `dx` at the same time.