added checkpointing to gauss adjoint #884
Conversation
AI-Maintainer Review for PR - Added checkpointing to gauss adjoint
Title and Description 👍
Scope of Changes 👍
Testing ⚠️
Documentation ⚠️
- `GaussCheckpointSolution`
- `Gaussfindcursor`
- `Gaussreset_p`
- `GaussAdjoint`
- `setvjp`
- `adjoint_sensitivities`

These entities should have docstrings added to describe their behavior, arguments, and return values.
Suggested Changes

- Please add docstrings to the `GaussCheckpointSolution`, `Gaussfindcursor`, `Gaussreset_p`, `GaussAdjoint`, `setvjp`, and `adjoint_sensitivities` entities to describe their behavior, arguments, and return values.
- Please provide more information about how you tested the changes, including the test environment, test inputs, and expected outcomes.
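As a sketch of what one of the requested docstrings could look like (the keyword names shown are illustrative assumptions, not necessarily the package's exact API), a Julia docstring for `GaussAdjoint` might read:

```julia
"""
    GaussAdjoint(; autojacvec = nothing, checkpointing = false)

Adjoint sensitivity method that computes the parameter gradient via Gauss
quadrature over the adjoint solution rather than by augmenting the adjoint
ODE with quadrature states.

# Keyword Arguments (illustrative)
- `autojacvec`: backend for vector-Jacobian products, e.g. `ZygoteVJP()`.
- `checkpointing`: if `true`, re-solve the forward pass from stored
  checkpoints during the backward pass instead of keeping a dense solution.
"""
```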
Reviewed with AI Maintainer
@avik-pal would you know where that is from?
Not really. Core2 tests are passing in other CI https://github.com/SciML/SciMLSensitivity.jl/actions/runs/6001815834/job/16276828862?pr=885. Also I can't find any …
Do the other ones do this too? We only have low-tolerance tests.
src/sensitivity_interface.jl (Outdated)
I think the issue is here. I removed `checkpoints` from the kwargs because of an error that `:checkpoints` is an unrecognized kwarg for `solve`. However, I just realized that I don't think the correct checkpoints are being used now, since `checkpoints` is not being passed to anything.
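One way to avoid the unrecognized-kwarg error without dropping the value entirely is to split `checkpoints` off from the kwargs before splatting the rest into `solve`. This is a hypothetical helper sketch, not the fix the PR actually landed; the function name and the default grid are made up for illustration:

```julia
# Hypothetical sketch: pull :checkpoints out of the kwargs so it is still
# available for the checkpointing logic, while `solve` only ever sees
# kwargs it recognizes.
function _forward_solve_with_checkpoints(prob, alg; kwargs...)
    nt = values(kwargs)  # kwargs as a NamedTuple
    checkpoints = get(nt, :checkpoints, range(prob.tspan[1], prob.tspan[2], length = 10))
    solve_kwargs = Base.structdiff(nt, NamedTuple{(:checkpoints,)})  # drop :checkpoints
    sol = solve(prob, alg; solve_kwargs...)
    return sol, checkpoints
end
```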
The other checkpointing tests (for the interpolating adjoint) work at a tolerance of 1e-9.
Ah, I think I see now.
I fixed errors in the adjoint state that came from not passing the proper tolerances into the checkpointing solves. However, the gradient calculation still only works at tolerances of 1e-4.
Is the solve dense and not using …
The checkpoint solves are dense but the adjoint solve is not |
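The dense/non-dense distinction matters here: a dense solve stores a continuous interpolant that the adjoint pass can query at arbitrary times, while a non-dense solve only keeps the saved points. A minimal sketch of the difference (variable names are illustrative):

```julia
# Dense solve: keeps a continuous interpolant, so u(t) can be queried
# anywhere in the time span.
sol_dense = solve(prob, Tsit5(); dense = true)
u_mid = sol_dense((prob.tspan[1] + prob.tspan[2]) / 2)  # interpolated state

# Non-dense solve: only the `saveat` points are stored, so intermediate
# states must be reconstructed, e.g. by re-solving from a checkpoint.
sol_sparse = solve(prob, Tsit5(); saveat = 0.1, dense = false)
```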
Codecov Report

```
@@            Coverage Diff             @@
##           master     #884       +/-   ##
===========================================
- Coverage   61.91%   48.45%   -13.47%
===========================================
  Files          20       20
  Lines        4385     4658      +273
===========================================
- Hits         2715     2257      -458
- Misses       1670     2401      +731
```

... and 6 files with indirect coverage changes
Which solve were you referring to? When I use a dense forward solve with checkpointing, GaussAdjoint works to 1e-9 tolerance, but this feels like cheating. The issue is in the forward solves, since the adjoint solution matches perfectly and the callback times line up.
@ChrisRackauckas I figured out the problem: the integrand function is made ahead of time with the solution to the non-dense solve. Therefore, it is using the parameter Jacobians from the non-dense solve in the integral calculation, causing it to be less accurate. I think the best solution is to make a callback which takes as an argument the integrator and a …
Add the GaussIntegrand to the ODEGaussAdjointSensitivityFunction and make it a mutable struct so you can mutate it.
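The suggestion above could be sketched as follows. `ODEGaussAdjointSensitivityFunction` and `GaussIntegrand` are real names from this PR, but the field layout and the constructor/update signatures here are assumptions for illustration only:

```julia
# Illustrative sketch: hold the integrand as a field of a mutable struct so
# that a checkpointing callback can swap in a fresh integrand built from the
# dense checkpoint solve, instead of closing over the stale non-dense
# forward solution.
mutable struct ODEGaussAdjointSensitivityFunction{I}
    integrand::I  # rebuilt whenever the active checkpoint solution changes
end

function update_integrand!(S::ODEGaussAdjointSensitivityFunction, checkpoint_sol)
    # Rebuild the parameter-Jacobian integrand from the dense checkpoint
    # solution so the quadrature sees accurate states. The `GaussIntegrand`
    # constructor call here is an assumed signature.
    S.integrand = GaussIntegrand(checkpoint_sol)
    return S
end
```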
Thanks, checkpointing now works for GaussAdjoint!
Looks like there are still a few test failures?
Checkpointing tests pass, but I am still working on the SDE tests.
test/adjoint.jl (Outdated)
```julia
using Lux, Optimization, Plots, Random
```
This is not in the right place?
test/adjoint.jl (Outdated)
```julia
function dudt(u, p, t)
    global st
    #input_val = u_vals[Int(round(t*10)+1)]
    out, st = nn_model(vcat(u[1], ex[Int(round(10 * 0.1))]), p, st)
    return out
end

prob = ODEProblem(dudt, u0, tspan, nothing)

function predict_neuralode(p)
    _prob = remake(prob, p = p)
    Array(solve(_prob, Tsit5(), saveat = tsteps, abstol = 1e-8, reltol = 1e-6,
        sensealg = GaussAdjoint(autojacvec = ZygoteVJP())))
end

function loss(p)
    sol = predict_neuralode(p)
    N = length(sol)
    return sum(abs2.(y[1:N] .- sol')) / N
end

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, p_model)

tmp1 = Zygote.gradient(loss, ComponentArray(p_model))
tmp2 = Zygote.gradient(loss, p_model)

res0 = Optimization.solve(optprob, PolyOpt(), maxiters = 100)
```
Test just the adjoint
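A minimal sketch of what "test just the adjoint" could look like, reusing the `loss` and `p_model` names from the diff above: compare the GaussAdjoint gradient against a ForwardDiff reference instead of running a full optimization loop. The tolerance shown is an assumption, not a value from the PR:

```julia
using Test, Zygote, ForwardDiff

# Reverse-mode gradient goes through the GaussAdjoint sensealg set inside
# `predict_neuralode`; ForwardDiff gives an adjoint-free reference.
grad_adjoint = Zygote.gradient(loss, p_model)[1]
grad_fd = ForwardDiff.gradient(loss, p_model)
@test grad_adjoint ≈ grad_fd rtol = 1e-4
```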
@ChrisRackauckas I removed the SDE stuff, so this should be ready to go with the checkpointing and nonvector parameter implementations.
It looks like there are two test failures still seen, one in the docs and one in Core 3.
See: #869. I think this needs some more testing first. It only seems to be working at lower tolerances (1e-4), so I need to figure out why.