Wrong integrator.t (t not in tstops) sent to callback during Zygote gradient calculation. #953

Open
Janssena opened this issue Apr 24, 2023 · 1 comment

@Janssena

Not entirely sure if this should fall under DifferentialEquations.jl, but here we go.

I am implementing a NeuralODE with callbacks myself using DifferentialEquations (not using DiffEqFlux at the moment).
I am noticing some strange behaviour in the time steps being sent to the callback function when calculating the gradient of the model with respect to a set of observations.

Essentially, I have a callback that sets a parameter (let's call it p) from 0 to a specific value and back to 0.
As an example: for t < 0, p = 0; a tstop at t = 0 is triggered and sets p = 100; then after some interval, for example at t = 1, another tstop is triggered and sets p = 0.
For most evaluations, the gradient is calculated fine.
Sometimes, however, depending on the values of the other parameters, the solve call adds a time step between my tstops, i.e. between t = 0 and t = 1, and as a result the gradient calculation using Zygote errors. I found that this is due to the integrator time that is passed to the affect! function of my callback.

The condition() and affect!() are as follows (pseudo-code):

values = [100, 0]
tstops = [0, 1]
condition(u, t, integrator) = t in tstops
affect!(integrator) = integrator.p[end] = values[findfirst(isequal(integrator.t), tstops)]
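
A minimal forward-only sketch of this setup, assuming OrdinaryDiffEq.jl as the solver package. The one-state dynamics `rhs`, the names `switch_times`, `switch_values` and `applied`, the initial state, the time span and the parameter values are all hypothetical stand-ins, not taken from the actual model. It only shows the callback firing at the tstops in a plain solve and does not reproduce the gradient problem.

```julia
using OrdinaryDiffEq

# Hypothetical dynamics: one state, last parameter is the input switched by the callback.
rhs(u, p, t) = [-p[1] * u[1] + p[end]]

switch_times = [0.0, 1.0]      # the tstops from the description above
switch_values = [100.0, 0.0]   # value written to p[end] at each tstop
applied = Float64[]            # times at which affect! actually changed p

# The third argument of a DiscreteCallback condition is the integrator.
condition(u, t, integrator) = t in switch_times

function affect!(integrator)
    idx = findfirst(isequal(integrator.t), switch_times)
    idx === nothing && return  # time is not one of the tstops
    push!(applied, integrator.t)
    integrator.p[end] = switch_values[idx]
end

cb = DiscreteCallback(condition, affect!)
prob = ODEProblem(rhs, [0.0], (-1.0, 2.0), [0.5, 0.0])
sol = solve(prob, Tsit5(); callback = cb, tstops = switch_times)
```

In this forward solve, `applied` records which times reached `affect!`, which gives a baseline to compare against the times printed during the gradient call.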

The error originates from the affect! function. During gradient calculation, instead of being called at my tstops, affect! is called with the time step in between my tstops. By adding a println to the affect! function I can see the following happening:

function affect!(integrator)
  println("From affect!: integrator.t = ", integrator.t)
  idx = findfirst(isequal(integrator.t), tstops)
  idx === nothing && return # exit the function to prevent the indexing error
  integrator.p[end] = values[idx]
end
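
The guard above can be factored into a small lookup helper. This is a minimal sketch using only Base Julia; the names `tstop_index` and `values_at_tstops` and the tolerance `atol = 1e-8` are hypothetical, and the approximate comparison is an assumption rather than something the behaviour described here is known to require. It does not explain why an off-grid time such as 0.5 reaches affect!; it only makes the "not a tstop" case explicit.

```julia
# Hypothetical helper: index of the tstop matching time `t`, or `nothing` if `t` is off-grid.
function tstop_index(t, tstops; atol = 1e-8)
    return findfirst(s -> isapprox(t, s; atol = atol), tstops)
end

tstops = [0.0, 1.0]
values_at_tstops = [100.0, 0.0]

# Check a time at each tstop and one in between.
for t in (0.0, 0.5, 1.0)
    idx = tstop_index(t, tstops)
    if idx === nothing
        println("t = ", t, " is not a tstop; skipping")
    else
        println("t = ", t, " -> p[end] = ", values_at_tstops[idx])
    end
end
```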

When no error occurs:

julia> Zygote.gradient(...)
# stuff from the forward call
From affect!: integrator.t = TrackedReal<9sm>(0.0, 0.0, KpU, 1, Caa)
From affect!: integrator.t = TrackedReal<6Zy>(0.0, 0.0, 2dM, 1, Jmq)
From affect!: integrator.t = TrackedReal<48h>(1.0, 0.0, E8U, 1, 70e)
From affect!: integrator.t = TrackedReal<K76>(1.0, 0.0, EW0, 1, 97m)
From affect!: integrator.t = TrackedReal<7Nk>(0.0, 0.0, LgX, 1, 6nD)
From affect!: integrator.t = TrackedReal<486>(0.0, 0.0, 1Y0, 1, DXb)

When the error occurs:

julia> Zygote.gradient(...)
# stuff from the forward call
From affect!: integrator.t = TrackedReal<9sm>(0.5, 0.0, KpU, 1, Caa) <- the value between my tstops is passed rather than 0
From affect!: integrator.t = TrackedReal<6Zy>(0.5, 0.0, 2dM, 1, Jmq) <- the value between my tstops is passed rather than 0
From affect!: integrator.t = TrackedReal<48h>(1.0, 0.0, E8U, 1, 70e)
From affect!: integrator.t = TrackedReal<K76>(1.0, 0.0, EW0, 1, 97m)
From affect!: integrator.t = TrackedReal<7Nk>(0.0, 0.0, LgX, 1, 6nD)
From affect!: integrator.t = TrackedReal<486>(0.0, 0.0, 1Y0, 1, DXb)

What is happening here? I hope someone can point me to the right place to look.
I can try to create an MWE if this doesn't immediately ring any bells.

@ChrisRackauckas
Member

@frankschae can you take a look at this? I just found it on the bottom of my todo email and realized I never looked into it. But I know you've done some improvements to this part of the adjoints recently so maybe this just works?
