I'm not entirely sure whether this belongs under DifferentialEquations.jl, but here we go.
I am implementing a neural ODE with callbacks myself, directly on top of DifferentialEquations.jl (not using DiffEqFlux at the moment).
I am noticing some strange behaviour in the time steps being sent to the callback function when calculating the gradient of the model with respect to a set of observations.
Essentially, I have a callback that switches a parameter (let's call it p) from 0 to a specific value and back to 0.
As an example: for t < 0, p = 0; at t = 0 a tstop is triggered and sets p = 100; then, after some interval (for example at t = 1), another tstop is triggered and sets p back to 0.
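For anyone trying to reproduce this, a minimal forward-only sketch of the switching pattern might look like the following. It uses the common `DiscreteCallback` plus `tstops` pattern, where `condition` checks `t in tstops` and `tstops` is also passed to `solve` so the integrator steps exactly onto the switch times. The toy dynamics `rhs`, the time span, and the names `switch_values` and `condition` are hypothetical stand-ins rather than the original model, and no gradient is computed here.

```julia
using OrdinaryDiffEq

# Hypothetical toy dynamics (not the original model): du/dt = -u + p[end],
# where p[end] is the input switched on and off by the callback.
rhs(u, p, t) = -u .+ p[end]

const tstops = [0.0, 1.0]            # switch-on and switch-off times
const switch_values = [100.0, 0.0]   # value assigned to p[end] at each tstop

# Fire only at the preset stop times; passing `tstops` to `solve`
# makes the integrator land exactly on them.
condition(u, t, integrator) = t in tstops

function affect!(integrator)
    i = findfirst(isequal(integrator.t), tstops)
    i === nothing && return          # not one of the switch times
    integrator.p[end] = switch_values[i]
end

cb = DiscreteCallback(condition, affect!)
prob = ODEProblem(rhs, [0.0], (-1.0, 2.0), [0.0])
sol = solve(prob, Tsit5(); callback = cb, tstops = tstops)
```

In this toy model u stays at 0 until t = 0, rises as 100(1 - e^(-t)) while p = 100, and decays after the switch-off at t = 1.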
For most evaluations, the gradient calculates fine.
Sometimes, however, depending on the values of the other parameters, the solve call adds a time step between my tstops (i.e. between t = 0 and t = 1), and as a result the gradient calculation with Zygote throws an error. I traced this to the integrator time that gets passed to the affect! function of my callback.
The condition() and affect!() are as follows (pseudo-code):
The error originates from the affect! function. During the gradient calculation, affect! is not always called at my tstops; instead it is called with the time step in between them. After adding a println to affect!, I can see the following:
```julia
function affect!(integrator)
    println("From affect!: integrator.t = ", integrator.t)
    i = findfirst(isequal(integrator.t), tstops)
    i === nothing && return          # t is not one of the tstops: exit to prevent an error
    integrator.p[end] = values[i]
end
```
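Because the exact `isequal` match against `tstops` is what fails when affect! receives an intermediate time such as 0.5, one possible workaround is to look up the value of the interval that contains `t` instead of requiring an exact hit. The sketch below assumes `tstops` is sorted and that `switch_values[i]` applies on `[tstops[i], tstops[i+1])`. `switched_value` and `switch_values` are hypothetical names, and this does not explain why the reverse pass calls affect! at 0.5. Whether returning the active interval's value is the right behaviour during the adjoint pass, and whether `searchsortedlast` works with `TrackedReal` times, are assumptions to verify.

```julia
const tstops = [0.0, 1.0]            # sorted switch times
const switch_values = [100.0, 0.0]   # value of p[end] from each switch time onward

# Hypothetical helper: value p[end] should hold at time t.
# Returns p0 before the first switch time.
function switched_value(t, stops, vals; p0 = zero(eltype(vals)))
    i = searchsortedlast(stops, t)   # index of the last stop <= t (0 if none)
    return i == 0 ? p0 : vals[i]
end

# Tolerant affect!: never fails the lookup, even for t between tstops.
function affect!(integrator)
    integrator.p[end] = switched_value(integrator.t, tstops, switch_values)
end
```

With these definitions, `switched_value(0.5, tstops, switch_values)` gives 100.0 rather than falling through the `nothing` branch.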
When no error occurs:
```julia
julia> Zygote.gradient(...)
# stuff from the forward call
From affect!: integrator.t = TrackedReal<9sm>(0.0, 0.0, KpU, 1, Caa)
From affect!: integrator.t = TrackedReal<6Zy>(0.0, 0.0, 2dM, 1, Jmq)
From affect!: integrator.t = TrackedReal<48h>(1.0, 0.0, E8U, 1, 70e)
From affect!: integrator.t = TrackedReal<K76>(1.0, 0.0, EW0, 1, 97m)
From affect!: integrator.t = TrackedReal<7Nk>(0.0, 0.0, LgX, 1, 6nD)
From affect!: integrator.t = TrackedReal<486>(0.0, 0.0, 1Y0, 1, DXb)
```
When the error occurs:
```julia
julia> Zygote.gradient(...)
# stuff from the forward call
From affect!: integrator.t = TrackedReal<9sm>(0.5, 0.0, KpU, 1, Caa)  # <- the value between my tstops is passed rather than 0
From affect!: integrator.t = TrackedReal<6Zy>(0.5, 0.0, 2dM, 1, Jmq)  # <- the value between my tstops is passed rather than 0
From affect!: integrator.t = TrackedReal<48h>(1.0, 0.0, E8U, 1, 70e)
From affect!: integrator.t = TrackedReal<K76>(1.0, 0.0, EW0, 1, 97m)
From affect!: integrator.t = TrackedReal<7Nk>(0.0, 0.0, LgX, 1, 6nD)
From affect!: integrator.t = TrackedReal<486>(0.0, 0.0, 1Y0, 1, DXb)
```
What is happening here? I hope someone can point me to the right place.
I can try to create a minimal working example (MWE) if this doesn't immediately ring any bells.
@frankschae can you take a look at this? I just found it at the bottom of my todo email and realized I never looked into it. But I know you've made some improvements to this part of the adjoints recently, so maybe this just works now?