The nested AD used in DiffEqFlux is not ideal: it calls ForwardDiff.gradient/jacobian, and Zygote's overrides for those calls compute the Hessian before doing the Hessian-vector product (HVP).
BatchedRoutines.jl has routines to do this efficiently, so we should start migrating the code to use them.
Note that this cannot be upstreamed to Zygote, because it requires capturing a different gradient / jacobian call to compute $\frac{\partial}{\partial p}\left(\frac{\partial f}{\partial u}\right)^T v$. Capturing the ForwardDiff calls only allows us to override $\frac{\partial}{\partial u}\left(\frac{\partial f}{\partial u}\right)^T v$.
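To illustrate why forming the Hessian first is wasteful, here is a minimal sketch using only ForwardDiff. It is not the DiffEqFlux, Zygote, or BatchedRoutines.jl code path; the test function `f` and the helper names `hvp_naive` and `hvp_directional` are made up for this example. The second helper differentiates the gradient along the direction `v` only, so the n×n Hessian is never materialized.

```julia
using ForwardDiff

# Hypothetical scalar test function (not taken from DiffEqFlux).
f(u) = sum(abs2, u) / 2 + prod(u)

# Naive route: build the full n×n Hessian, then multiply by v.
hvp_naive(f, u, v) = ForwardDiff.hessian(f, u) * v

# Directional route: take d/dε of ∇f(u + ε v) at ε = 0.
# This yields H(u) * v without ever storing H(u).
hvp_directional(f, u, v) =
    ForwardDiff.derivative(ε -> ForwardDiff.gradient(f, u .+ ε .* v), 0.0)

u = [1.0, 2.0, 3.0]
v = [0.5, -1.0, 2.0]

# Both routes should agree up to floating-point error.
@assert hvp_naive(f, u, v) ≈ hvp_directional(f, u, v)
```

The sketch uses forward-over-forward differentiation purely to stay within one package; it only shows the cost difference between the two routes and does not address the $\frac{\partial}{\partial p}$ term discussed above.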
(I will probably do the migration myself over the summer if no one else picks it up.)
Update on this: LuxDL/Lux.jl#598 will handle everything automatically, so relying on another package is unnecessary. We just need to change the Zygote.gradient(f, x, ps) calls to Zygote.gradient(::StatefulLuxLayer, x).