
Remove duplicate callbacks in ForwardDiff and ReverseDiff adjoints #1087

Merged: 14 commits, merged on Aug 2, 2024


@jClugstor (Contributor) commented on Jul 30, 2024

Checklist

  • Appropriate tests were added
  • Any code changes were done in a way that does not break public API
  • All documentation related to code changes was updated
  • The new code follows the contributor guidelines, in particular the SciML Style Guide and COLPRAC.
  • Any new documentation only uses public API

Additional context

Fixes #1081

In some of the adjoint methods, if a problem that was constructed with a callback is passed in, the callback is called twice. This is because the kwargs argument already contains the callback that was passed to the problem constructor. When the problem is then solved, it uses both the callback stored in the problem itself and the callback passed to solve, so the same callback operation is applied twice.

This PR makes the adjoint methods use only the callback in kwargs.
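The mechanism can be sketched in plain Julia, without any SciML packages. `ToyProblem`, `toy_solve`, `adjoint_buggy`, and `adjoint_fixed` are hypothetical stand-ins, not SciMLSensitivity or DiffEqBase internals. The sketch only assumes what is described above: the adjoint receives kwargs that already contain the problem's callback, and solving the problem again also fires the callback stored in the problem. Clearing the problem's copy is just one way to express "use only the callback in kwargs", not necessarily how the PR does it.

```julia
# Hypothetical toy model of the double-callback bug, in plain Julia.
# None of these names exist in SciMLSensitivity or DiffEqBase.
struct ToyProblem
    kwargs::NamedTuple   # keyword arguments stored when the problem is built
end

# Toy solve: fires the callback stored in the problem AND the one in kwargs.
function toy_solve(prob::ToyProblem; kwargs...)
    callbacks = Function[]
    cb_prob = get(prob.kwargs, :callback, nothing)
    cb_prob === nothing || push!(callbacks, cb_prob)
    cb_call = get(kwargs, :callback, nothing)
    cb_call === nothing || push!(callbacks, cb_call)
    foreach(cb -> cb(), callbacks)
    return length(callbacks)             # number of callbacks that fired
end

# Before the fix: kwargs already carry the problem's callback, so it fires twice.
adjoint_buggy(prob::ToyProblem; kwargs...) = toy_solve(prob; kwargs...)

# After the fix (sketch): clear the problem's copy so only kwargs are used.
function adjoint_fixed(prob::ToyProblem; kwargs...)
    stripped = ToyProblem(merge(prob.kwargs, (callback = nothing,)))
    return toy_solve(stripped; kwargs...)
end

ncalls = Ref(0)
count_cb() = (ncalls[] += 1; nothing)
prob = ToyProblem((callback = count_cb,))

# Simulate kwargs that already contain the problem's callback.
adjoint_buggy(prob; callback = count_cb)   # returns 2: the callback ran twice
adjoint_fixed(prob; callback = count_cb)   # returns 1
```

Counting invocations with a `Ref` makes the duplication visible directly, which is also why checking call counts is a natural way to test the fix.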

@ChrisRackauckas (Member)

Needs tests: make the problem from the issue into a test case.

@ChrisRackauckas (Member)

Also handle ReverseDiffAdjoint() and TrackerAdjoint(), as mentioned in the issue.

@jClugstor (Contributor Author)

It errors out when using ReverseDiffAdjoint or TrackerAdjoint, since the callback tries to mutate the parameter array.

@ChrisRackauckas (Member)

Mark those cases as @test_broken and comment on the original issue. Maybe test another callback that doesn't require mutation?

@jClugstor (Contributor Author)

It turns out TrackerAdjoint had the same problem.
I made the tests check that the callback is called the same number of times whether it is passed to the problem or to solve, since the duplicated calls are the root of the issue.


```julia
@testset "Callback duplication check" begin
    for adjoint_type in [
        ForwardDiffSensitivity(), ReverseDiffAdjoint(), TrackerAdjoint()]
        # ... (rest of the test not shown in the review excerpt)
    end
end
```
Member (review comment)

And loop through some adjoint methods too; looping over all sensealgs is best.

Member (review comment)

Except ZygoteAdjoint

Contributor Author (review comment)

ZygoteAdjoint() and GaussAdjoint() error out. Besides those, I've included all the other sensealgs that support callbacks.

Member (review comment)

For the ones that error, label them as @test_broken.

GaussAdjoint with callbacks is currently on the critical to-do list.
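The testing approach discussed in this thread can be sketched with Julia's `Test` standard library alone: check that the callback fires the same number of times whether it is attached to the problem or passed to `solve`, loop over the sensealgs, and mark the ones that currently fail as `@test_broken`. `count_callback_calls`, the symbol names, and the toy counts are hypothetical stand-ins for actually solving with `ForwardDiffSensitivity()`, `ReverseDiffAdjoint()`, and so on. This is a sketch of the pattern, not the merged test.

```julia
using Test

# Hypothetical stand-in for "solve with this sensealg and count how often the
# callback fires". attach is :problem (callback given to the problem
# constructor) or :solve (callback passed to solve). Toy numbers only.
function count_callback_calls(sensealg::Symbol, attach::Symbol)
    # Stands in for a sensealg that currently errors with this callback.
    sensealg === :erroring_alg && error("$sensealg errors with this callback")
    fired_from_kwargs = 1
    # Stands in for a sensealg that still fires the problem's copy as well.
    fired_from_problem = (attach === :problem && sensealg === :duplicating_alg) ? 1 : 0
    return fired_from_kwargs + fired_from_problem
end

const SENSEALGS = [:alg_a, :alg_b, :duplicating_alg, :erroring_alg]
const KNOWN_BROKEN = Set([:duplicating_alg, :erroring_alg])

@testset "Callback duplication check" begin
    for alg in SENSEALGS
        if alg in KNOWN_BROKEN
            # Recorded as Broken if the counts differ or the call throws.
            @test_broken(count_callback_calls(alg, :problem) ==
                         count_callback_calls(alg, :solve))
        else
            @test count_callback_calls(alg, :problem) == count_callback_calls(alg, :solve)
        end
    end
end
```

`@test_broken` records a Broken result when its expression is false or throws, so it covers both a sensealg that still duplicates the callback and one that errors outright. If a broken case starts passing, `Test` reports an unexpected pass, which signals that the marker can be removed.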

@ChrisRackauckas merged commit 516831c into SciML:master on Aug 2, 2024 (12 of 16 checks passed).
Linked issue: Including callback in problem vs solve call changes gradient for ForwardDiffSensitivity