-
-
Notifications
You must be signed in to change notification settings - Fork 80
Ito sense for SDE adjoints #317
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
add Ito sense to SDE adjoints
|
All tests in |
|
Not sure if there is still one conceptual issue (maybe related to the SOSRI issue in SciML/DiffEqNoiseProcess.jl#62) The reversion using EM for the 1D Ito SDE should be okay according to what I saw in: SciML/DiffEqNoiseProcess.jl#62 These tests are also all fine: res_sde_u0a, res_sde_pa = adjoint_sensitivities(sol,EM(),dg!,Array(t)
,dt=dtscalar,adaptive=false,sensealg=BacksolveAdjoint(autojacvec=DiffEqSensitivity.ReverseDiffVJP()))
@info res_sde_u0a, res_sde_pa
res_sde_u0b, res_sde_pb = adjoint_sensitivities(sol,EM(),dg!,Array(t)
,dt=dtscalar,adaptive=false,sensealg=BacksolveAdjoint(autojacvec=false))
@info res_sde_u0b, res_sde_pb
@test isapprox(res_sde_u0a, res_sde_u0b, rtol=1e-9)
@test isapprox(res_sde_pa, res_sde_pb, rtol=1e-9)
res_sde_u0b, res_sde_pb = adjoint_sensitivities(sol,EM(),dg!,Array(t)
,dt=dtscalar,adaptive=false,sensealg=BacksolveAdjoint(autojacvec=false), corfunc_analytical=corfunc)
@info res_sde_u0b, res_sde_pb
@test isapprox(res_sde_u0a, res_sde_u0b, rtol=1e-9)
@test isapprox(res_sde_pa, res_sde_pb, rtol=1e-9)That should imply that the correction factor is correctly implemented since ReverseDiff over Zygote, ForwardDiff over Zygote, and also ForwardDiff over analytical correction factor all coincide. However the gradients do not yet agree with the analytical / ForwardDiff result. |

@ChrisRackauckas Could you check if the way to add the Ito sense goes in the right direction?
I thought an additional type for the transformed function with its own vips are the best way to do it. However, I couldn't resolve up to now all errors in the vjps from the SensitivityFunction. (ReverseDiff is the only one which I could make running through the code but also there the gradients look slightly off.. )