Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove duplicate callbacks in ForwardDiff and ReverseDiff adjoints #1087

Merged
merged 14 commits into from
Aug 2, 2024
45 changes: 31 additions & 14 deletions src/concrete_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -681,7 +681,9 @@ function DiffEqBase._concrete_solve_adjoint(prob::SciMLBase.AbstractODEProblem,
prob.f
end

_prob = ODEForwardSensitivityProblem(_f, u0, prob.tspan, p, sensealg)
# callback = nothing ensures only the callback in kwargs is used
_prob = ODEForwardSensitivityProblem(
_f, u0, prob.tspan, p, sensealg, callback = nothing)
sol = solve(_prob, alg, args...; kwargs...)
_, du = extract_local_sensitivities(sol, sensealg, Val(true))
ts = current_time(sol)
Expand Down Expand Up @@ -729,7 +731,8 @@ function DiffEqBase._concrete_solve_forward(prob::SciMLBase.AbstractODEProblem,
u0, p, originator::SciMLBase.ADOriginator,
args...; save_idxs = nothing,
kwargs...)
_prob = ODEForwardSensitivityProblem(prob.f, u0, prob.tspan, p, sensealg)
_prob = ODEForwardSensitivityProblem(
prob.f, u0, prob.tspan, p, sensealg, callback = nothing)
sol = solve(_prob, args...; kwargs...)
u, du = extract_local_sensitivities(sol, Val(true))
_save_idxs = save_idxs === nothing ? (1:length(u0)) : save_idxs
Expand Down Expand Up @@ -786,7 +789,9 @@ function DiffEqBase._concrete_solve_adjoint(
_saveat = saveat
end

sol = solve(remake(prob, p = p, u0 = u0), alg, args...; saveat = _saveat, kwargs...)
# use the callback in kwargs, not prob
sol = solve(remake(prob, p = p, u0 = u0, callback = nothing),
alg, args...; saveat = _saveat, kwargs...)

# saveat values
# need all values here. Not only unique ones.
Expand Down Expand Up @@ -864,7 +869,10 @@ function DiffEqBase._concrete_solve_adjoint(
else
_f = prob.f
end
_prob = remake(prob, f = _f, u0 = u0dual, p = pdual, tspan = tspandual)

# use the callback from kwargs, not prob
_prob = remake(prob, f = _f, u0 = u0dual, p = pdual,
tspan = tspandual, callback = nothing)

if _prob isa SDEProblem
_prob.noise_rate_prototype !== nothing && (_prob = remake(_prob,
Expand Down Expand Up @@ -1019,7 +1027,9 @@ function DiffEqBase._concrete_solve_adjoint(
_f = prob.f
end

_prob = remake(prob, f = _f, u0 = u0dual, p = pdual, tspan = tspandual)
# use the callback from kwargs, not prob
_prob = remake(prob, f = _f, u0 = u0dual, p = pdual,
tspan = tspandual, callback = nothing)

if _prob isa SDEProblem
_prob.noise_rate_prototype !== nothing && (_prob = remake(_prob,
Expand Down Expand Up @@ -1204,9 +1214,11 @@ function DiffEqBase._concrete_solve_adjoint(
(prob.f.f isa FunctionWrappersWrappers.FunctionWrappersWrapper ||
SciMLBase.specialization(prob.f) === SciMLBase.AutoSpecialize)
f = ODEFunction{isinplace(prob), SciMLBase.FullSpecialize}(unwrapped_f(prob.f))
_prob = remake(prob, f = f, u0 = map(identity, _u0), p = _p, tspan = _tspan)
_prob = remake(prob, f = f, u0 = map(identity, _u0),
p = _p, tspan = _tspan, callback = nothing)
else
_prob = remake(prob, u0 = map(identity, _u0), p = _p, tspan = _tspan)
_prob = remake(prob, u0 = map(identity, _u0), p = _p,
tspan = _tspan, callback = nothing)
end
else
# use TrackedArray for efficiency of the tape
Expand Down Expand Up @@ -1237,13 +1249,15 @@ function DiffEqBase._concrete_solve_adjoint(
SciMLBase.FullSpecialize
}(_f,
_g),
u0 = _u0, p = SciMLStructures.replace(Tunable(), p, _p), tspan = _tspan)
u0 = _u0, p = SciMLStructures.replace(Tunable(), p, _p),
tspan = _tspan, callback = nothing)
else
_prob = remake(prob,
f = DiffEqBase.parameterless_type(prob.f){false,
SciMLBase.FullSpecialize
}(_f),
u0 = _u0, p = SciMLStructures.replace(Tunable(), p, _p), tspan = _tspan)
u0 = _u0, p = SciMLStructures.replace(Tunable(), p, _p),
tspan = _tspan, callback = nothing)
end
elseif prob isa
Union{SciMLBase.AbstractODEProblem, SciMLBase.AbstractSDEProblem}
Expand All @@ -1269,13 +1283,15 @@ function DiffEqBase._concrete_solve_adjoint(
SciMLBase.FullSpecialize
}(_f,
_g),
u0 = _u0, p = SciMLStructures.replace(Tunable(), p, _p), tspan = _tspan)
u0 = _u0, p = SciMLStructures.replace(Tunable(), p, _p),
tspan = _tspan, callback = nothing)
else
_prob = remake(prob,
f = DiffEqBase.parameterless_type(prob.f){false,
SciMLBase.FullSpecialize
}(_f),
u0 = _u0, p = SciMLStructures.replace(Tunable(), p, _p), tspan = _tspan)
u0 = _u0, p = SciMLStructures.replace(Tunable(), p, _p),
tspan = _tspan, callback = nothing)
end
else
error("TrackerAdjont does not currently support the specified problem type. Please open an issue.")
Expand Down Expand Up @@ -1408,7 +1424,8 @@ function DiffEqBase._concrete_solve_adjoint(
f = ODEFunction{isinplace(prob), SciMLBase.FullSpecialize}(unwrapped_f(prob.f))
_prob = remake(prob, f = f, u0 = reshape([x for x in _u0], size(_u0)),
p = _p,
tspan = _tspan)
tspan = _tspan,
callback = nothing)
else
_prob = remake(prob, u0 = reshape([x for x in _u0], size(_u0)), p = _p,
tspan = _tspan)
Expand All @@ -1423,13 +1440,13 @@ function DiffEqBase._concrete_solve_adjoint(
SciMLBase.isinplace(prob),
true}(_f,
_g),
u0 = _u0, p = _p, tspan = _tspan)
u0 = _u0, p = _p, tspan = _tspan, callback = nothing)
else
_prob = remake(prob,
f = DiffEqBase.parameterless_type(prob.f){
SciMLBase.isinplace(prob),
true}(_f),
u0 = _u0, p = _p, tspan = _tspan)
u0 = _u0, p = _p, tspan = _tspan, callback = nothing)
end
end

Expand Down
34 changes: 34 additions & 0 deletions test/prob_kwargs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,37 @@ a = ones(3)
@test Zygote.gradient(f, a)[1][1] ≈ Zygote.gradient(f2, a)[1][1]
@test Zygote.gradient(f, a)[1][2] == Zygote.gradient(f2, a)[1][2] == 0
@test Zygote.gradient(f, a)[1][3] == Zygote.gradient(f2, a)[1][3] == 0

# callback in problem construction or in solve call should give same result
# https://github.com/SciML/SciMLSensitivity.jl/issues/1081
odef(du, u, p, t) = du .= u .* p
prob = ODEProblem(odef, [2.0], (0.0, 1.0), [3.0])

let callback_count1 = 0, callback_count2 = 0
function f1(u0p, adjoint_type)
condition(u, t, integrator) = t == 0.5
affect!(integrator) = callback_count1 += 1
cb = DiscreteCallback(condition, affect!)
prob = ODEProblem{true}(odef, u0p[1:1], (0.0, 1.0), u0p[2:2]; callback = cb)
sum(solve(prob, Tsit5(), tstops = [0.5], sensealg = adjoint_type))
end

function f2(u0p, adjoint_type)
condition(u, t, integrator) = t == 0.5
affect!(integrator) = callback_count2 += 1
cb = DiscreteCallback(condition, affect!)
prob = ODEProblem{true}(odef, u0p[1:1], (0.0, 1.0), u0p[2:2])
sum(solve(prob, Tsit5(), tstops = [0.5], callback = cb, sensealg = adjoint_type))
end

@testset "Callback duplication check" begin
for adjoint_type in [
ForwardDiffSensitivity(), ReverseDiffAdjoint(), TrackerAdjoint()]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and loop through some adjoint methods too, all sensealgs is best

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Except ZygoteAdjoint

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ZygoteAdjoint() and GaussAdjoint() error out, besides those I've included all the rest that support callbacks

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the ones that error, label them as @test_broken.

Gauss adjoint with callbacks is currently on the critical Todo.

u0p = [2.0, 3.0]
Zygote.gradient(x -> f1(x, adjoint_type), u0p)
Zygote.gradient(x -> f2(x, adjoint_type), u0p)

@test callback_count1 == callback_count2
end
end
end
Loading