
added checkpointing to gauss adjoint #884

Closed

wants to merge 16 commits
Conversation

acoh64
Contributor

@acoh64 acoh64 commented Aug 26, 2023

See #869. I think this needs some more testing first: it only seems to be working at lower tolerances (1e-4), so I need to figure out why.


@ai-maintainer ai-maintainer bot left a comment


AI-Maintainer Review for PR - Added checkpointing to gauss adjoint

Title and Description 👍

The Title and Description are clear and focused
The title and description of the pull request are clear and focused. They effectively communicate the purpose of the changes, which is to add checkpointing to the Gauss adjoint algorithm. The author also acknowledges the need for further testing, since the feature currently only works at lower tolerances (around 1e-4).

Scope of Changes 👍

The changes are narrowly focused
The changes in this pull request are narrowly focused on adding checkpointing to the Gauss adjoint algorithm. The modifications are primarily in the `gauss_adjoint.jl` file, with some changes in the `sensitivity_algorithms.jl` and `sensitivity_interface.jl` files to support the new feature. The author is not trying to resolve multiple issues simultaneously.

Testing ⚠️

Testing details are not provided
The description does not provide specific details about how the author tested the changes. It would be helpful for the author to provide more information about the testing process, such as the test environment, test inputs, and expected outcomes, to ensure that the changes have been thoroughly tested and validated.

Documentation ⚠️

Docstrings are missing for some functions, classes, or methods
The following functions, classes, or methods do not have docstrings:
  • GaussCheckpointSolution
  • Gaussfindcursor
  • Gaussreset_p
  • GaussAdjoint
  • setvjp
  • adjoint_sensitivities

These entities should have docstrings added to describe their behavior, arguments, and return values.

Suggested Changes

  • Please add docstrings to the GaussCheckpointSolution, Gaussfindcursor, Gaussreset_p, GaussAdjoint, setvjp, and adjoint_sensitivities entities to describe their behavior, arguments, and return values.
  • Please provide more information about how you tested the changes, including the test environment, test inputs, and expected outcomes.

Reviewed with AI Maintainer

@acoh64 acoh64 changed the title added checkpointing to gauss adjoint, from https://github.com/SciML/SciMLSensitivity.jl/issues/869 added checkpointing to gauss adjoint Aug 26, 2023
@ChrisRackauckas
Member

@avik-pal would you know where that is from?

@avik-pal
Member

Not really. Core2 tests are passing in other CI: https://github.com/SciML/SciMLSensitivity.jl/actions/runs/6001815834/job/16276828862?pr=885. Also, I can't find any `.t` in the code, and the stacktrace seems incomplete?

@ChrisRackauckas
Member

It only seems to be working at lower tolerances (1e-4), so I need to figure out why

Do the other ones do this too? We only have low tolerance tests.

Contributor Author


I think the issue is here. I removed `checkpoints` from the kwargs because of an error that `:checkpoints` is an unrecognized kwarg for `solve`. However, I just realized that the correct checkpoints are probably not being used now, since `checkpoints` is no longer being passed to anything.
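A hypothetical sketch of the kind of fix being described here; all names are invented for illustration and are not the PR's actual code:

```julia
# Illustrative only: separate :checkpoints from the kwargs that are
# forwarded to solve, instead of dropping it entirely.
function forward_solve_with_checkpoints(prob, alg; kwargs...)
    kw = Dict{Symbol, Any}(kwargs)
    # Remove :checkpoints so solve does not error on an unrecognized kwarg...
    checkpoints = pop!(kw, :checkpoints, range(prob.tspan...; length = 10))
    sol = solve(prob, alg; kw...)
    # ...but return it so the checkpoint sub-solves can actually use it.
    return sol, checkpoints
end
```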

@acoh64
Contributor Author

acoh64 commented Aug 30, 2023

The other checkpointing tests (for the interpolating adjoint) work at a tolerance of 1e-9.

@acoh64
Contributor Author

acoh64 commented Aug 30, 2023

Ah, I think I see now.

@acoh64
Contributor Author

acoh64 commented Aug 30, 2023

I fixed errors in the adjoint state that came from not passing the proper tolerances into the checkpointing solves. However, the gradient calculation still only works at tolerances of 1e-4.

@ChrisRackauckas
Member

Is the solve dense and not using saveat?

@acoh64
Contributor Author

acoh64 commented Aug 30, 2023

The checkpoint solves are dense but the adjoint solve is not
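For readers following along, a rough sketch of the distinction being discussed, using standard DifferentialEquations.jl solver options on a made-up toy problem:

```julia
using OrdinaryDiffEq

f(u, p, t) = -p[1] .* u
prob = ODEProblem(f, [1.0], (0.0, 1.0), [2.0])

# Dense solve: stores a continuous high-order interpolation, so sol(t)
# can be queried accurately at any time point.
dense_sol = solve(prob, Tsit5())  # dense = true is the default

# saveat solve: only the requested times are stored and no dense
# interpolation is kept, which matters when an adjoint needs to
# interpolate the forward solution between save points.
saveat_sol = solve(prob, Tsit5(); saveat = 0.0:0.25:1.0)
```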

@codecov

codecov bot commented Aug 30, 2023

Codecov Report

Merging #884 (9ca3932) into master (5d03a76) will decrease coverage by 13.47%.
Report is 11 commits behind head on master.
The diff coverage is 0.32%.

❗ Current head 9ca3932 differs from pull request most recent head f1cd78c. Consider uploading reports for the commit f1cd78c to get more accurate results

@@             Coverage Diff             @@
##           master     #884       +/-   ##
===========================================
- Coverage   61.91%   48.45%   -13.47%     
===========================================
  Files          20       20               
  Lines        4385     4658      +273     
===========================================
- Hits         2715     2257      -458     
- Misses       1670     2401      +731     
Files Coverage Δ
src/concrete_solve.jl 65.88% <100.00%> (-3.50%) ⬇️
src/interpolating_adjoint.jl 71.02% <ø> (-4.99%) ⬇️
src/sensitivity_interface.jl 83.33% <ø> (-8.34%) ⬇️
src/derivative_wrappers.jl 77.97% <0.00%> (-12.23%) ⬇️
src/sensitivity_algorithms.jl 70.00% <0.00%> (-7.78%) ⬇️
src/gauss_adjoint.jl 0.00% <0.00%> (-65.70%) ⬇️

... and 6 files with indirect coverage changes


@acoh64
Contributor Author

acoh64 commented Sep 13, 2023

Which solve were you referring to? When I use the dense forward solve with checkpointing, then GaussAdjoint works to 1e-9 tolerance, but this feels like cheating.

The issue is in the forward solves, since the adjoint solution matches perfectly and the callback times line up

@acoh64
Contributor Author

acoh64 commented Sep 14, 2023

@ChrisRackauckas I figured out the problem: the integrand function is made ahead of time with the solution of the nondense solve. Therefore, it uses the parameter Jacobians from the nondense solve in the integral calculation, making it less accurate. I think the best solution is a callback that takes both the integrator and a `sol` as arguments; however, I don't think this can be done with the current callback interface. Is there a way around this?

@ChrisRackauckas
Member

Add the GaussIntegrand to the ODEGaussAdjointSensitivityFunction and make it a mutable struct so you can mutate integrand.sol when a new cpsol is taken.
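A minimal sketch of that suggestion; all names other than the mutated `sol` field are invented for illustration (the actual types live in `gauss_adjoint.jl`):

```julia
# Illustrative sketch: a mutable integrand whose forward solution can be
# swapped out whenever checkpointing produces a fresh dense sub-solve.
mutable struct IntegrandSketch{S}
    sol::S   # forward solution used to evaluate parameter Jacobians
end

# Called when a new checkpoint solution (cpsol) is taken: repoint the
# integrand so the quadrature interpolates the accurate local solve
# instead of the stale nondense global one.
function swap_forward_sol!(integrand::IntegrandSketch, cpsol)
    integrand.sol = cpsol
    return integrand
end
```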

@acoh64
Contributor Author

acoh64 commented Sep 18, 2023

Thanks, checkpointing now works for GaussAdjoint!

@ChrisRackauckas
Member

Looks like there's still a few test failures?

@acoh64
Contributor Author

acoh64 commented Sep 20, 2023

Checkpointing tests pass but I am still working on the SDE tests

Member

@frankschae frankschae left a comment


@glatteis @acoh64's GaussAdjoint will probably also solve the efficiency problem you spotted. I think we can add the following in the callback:

dy2, back = Zygote.pullback(y, p) do u, p
    g(u, p, t) .* dW 
end
out2 = back(λ)

test/adjoint.jl Outdated



using Lux, Optimization, Plots, Random
Member


This is not in the right place?

test/adjoint.jl Outdated
Comment on lines 33 to 60
function dudt(u, p, t)
    global st
    # input_val = u_vals[Int(round(t * 10) + 1)]
    out, st = nn_model(vcat(u[1], ex[Int(round(10 * 0.1))]), p, st)
    return out
end

prob = ODEProblem(dudt, u0, tspan, nothing)

function predict_neuralode(p)
    _prob = remake(prob, p = p)
    Array(solve(_prob, Tsit5(), saveat = tsteps, abstol = 1e-8, reltol = 1e-6,
        sensealg = GaussAdjoint(autojacvec = ZygoteVJP())))
end

function loss(p)
    sol = predict_neuralode(p)
    N = length(sol)
    return sum(abs2.(y[1:N] .- sol')) / N
end

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, p_model)

tmp1 = Zygote.gradient(loss, ComponentArray(p_model))
tmp2 = Zygote.gradient(loss, p_model)

res0 = Optimization.solve(optprob, PolyOpt(), maxiters = 100)
Member


Test just the adjoint

@acoh64
Contributor Author

acoh64 commented Sep 29, 2023

@ChrisRackauckas I removed the SDE stuff, so this should be ready to go with the checkpointing and non-vector parameter implementations.

@ChrisRackauckas
Member

It looks like there are still two test failures, one in the docs and one in Core 3.

@avik-pal avik-pal linked an issue Oct 3, 2023 that may be closed by this pull request
Successfully merging this pull request may close these issues.

Compatibility with Functors and non-vector parameters in Gauss Adjoint