From 14681269f49e36b9e78baf1bf252c7954cb7b03c Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Thu, 8 Sep 2022 11:05:45 -1000 Subject: [PATCH 1/3] More robust method checking before wrapping Fixes https://github.com/SciML/DiffEqBase.jl/issues/821 . The key is that promote_type ends up calling promote_rule, so it always has a method. We need to check if the promotion rules exist with duals to know if this we can do the wrapping. In the future, we can do further customization of the wrappers based on if the solver does AD or not, though that would likely increase compile times by a lot (since then there would be multiple versions of the solver), so I'm not convinced that's the right approach after. At least this is always safe, if not "too safe" in some sense. --- src/solve.jl | 5 ++++- test/downstream/Project.toml | 1 + test/downstream/unwrapping.jl | 10 ++++++++++ 3 files changed, 15 insertions(+), 1 deletion(-) diff --git a/src/solve.jl b/src/solve.jl index 4ce046fd1..fbc2a1306 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -968,7 +968,10 @@ function promote_f(f::F, ::Val{specialize}, u0, p, t) where {F, specialize} RecursiveArrayTools.recursive_unitless_eltype(u0) === eltype(u0) && one(t) === oneunit(t) && Tricks.static_hasmethod(ArrayInterfaceCore.promote_eltype, - Tuple{Type{typeof(u0)}, Type{eltype(u0)}})) || + Tuple{Type{typeof(u0)}, Type{dualgen(eltype(u0))}}) && + Tricks.static_hasmethod(promote_rule,Tuple{Type{eltype(u0)},Type{dualgen(eltype(u0))}}) && + Tricks.static_hasmethod(promote_rule,Tuple{Type{eltype(u0)},Type{typeof(t)}}) + ) || (specialize === SciMLBase.FunctionWrapperSpecialize && !(f.f isa FunctionWrappersWrappers.FunctionWrappersWrapper))) return wrapfun_iip(f.f, (u0, u0, p, t)) diff --git a/test/downstream/Project.toml b/test/downstream/Project.toml index 0083c3a78..f73ad0206 100644 --- a/test/downstream/Project.toml +++ b/test/downstream/Project.toml @@ -4,6 +4,7 @@ DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7" ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78" MultiScaleArrays = "f9640e96-87f6-5992-9c3b-0743c6a49ffa" ODEProblemLibrary = "fdc4e326-1af4-4b90-96e7-779fcce2daa5" diff --git a/test/downstream/unwrapping.jl b/test/downstream/unwrapping.jl index 45f08eabc..662d49729 100644 --- a/test/downstream/unwrapping.jl +++ b/test/downstream/unwrapping.jl @@ -8,3 +8,13 @@ integrator = init(ode, Tsit5()) ode = ODEProblem(my_f!, [1.0], (0.0, 1.0)) integrator = init(ode, Tsit5()) @test SciMLBase.unwrapped_f(integrator.f.f) === my_f! + +using ForwardDiff, Measurements +x = 1.0 ± 0.0 +f = (du,u,p,t)-> du .= u +tspan = (ForwardDiff.Dual(0.0,(0.01)),ForwardDiff.Dual(1.0,(0.01))) +prob = ODEProblem(f, [x], tspan) + +# Should not error during problem construction but should be unwrapped +integ = init(prob, Tsit5(), dt = 0.1) +@test integ.f.f === f \ No newline at end of file From 9f3e6ed33af5a2fc0eb32b18a75d49fd394f95ab Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Thu, 8 Sep 2022 11:07:15 -1000 Subject: [PATCH 2/3] more test cases --- test/downstream/unwrapping.jl | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/test/downstream/unwrapping.jl b/test/downstream/unwrapping.jl index 662d49729..e98ceb414 100644 --- a/test/downstream/unwrapping.jl +++ b/test/downstream/unwrapping.jl @@ -9,9 +9,16 @@ ode = ODEProblem(my_f!, [1.0], (0.0, 1.0)) integrator = init(ode, Tsit5()) @test SciMLBase.unwrapped_f(integrator.f.f) === my_f! -using ForwardDiff, Measurements +using OrdinaryDiffEq, ForwardDiff, Measurements x = 1.0 ± 0.0 f = (du,u,p,t)-> du .= u +tspan = (0.0,1.0) +prob = ODEProblem(f, [x], tspan) + +# Should not error during problem construction but should be unwrapped +integ = init(prob, Tsit5(), dt = 0.1) +@test integ.f.f === f + tspan = (ForwardDiff.Dual(0.0,(0.01)),ForwardDiff.Dual(1.0,(0.01))) prob = ODEProblem(f, [x], tspan) From f9af71c3abd2c722647ca2a15573753dfc4a8b9c Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Thu, 8 Sep 2022 11:17:10 -1000 Subject: [PATCH 3/3] format --- src/solve.jl | 7 ++++--- test/downstream/unwrapping.jl | 8 ++++---- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/src/solve.jl b/src/solve.jl index fbc2a1306..8f5c2b824 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -969,9 +969,10 @@ function promote_f(f::F, ::Val{specialize}, u0, p, t) where {F, specialize} one(t) === oneunit(t) && Tricks.static_hasmethod(ArrayInterfaceCore.promote_eltype, Tuple{Type{typeof(u0)}, Type{dualgen(eltype(u0))}}) && - Tricks.static_hasmethod(promote_rule,Tuple{Type{eltype(u0)},Type{dualgen(eltype(u0))}}) && - Tricks.static_hasmethod(promote_rule,Tuple{Type{eltype(u0)},Type{typeof(t)}}) - ) || + Tricks.static_hasmethod(promote_rule, + Tuple{Type{eltype(u0)}, Type{dualgen(eltype(u0))}}) && + Tricks.static_hasmethod(promote_rule, + Tuple{Type{eltype(u0)}, Type{typeof(t)}})) || (specialize === SciMLBase.FunctionWrapperSpecialize && !(f.f isa FunctionWrappersWrappers.FunctionWrappersWrapper))) return wrapfun_iip(f.f, (u0, u0, p, t)) diff --git a/test/downstream/unwrapping.jl b/test/downstream/unwrapping.jl index e98ceb414..2fc871f2b 100644 --- a/test/downstream/unwrapping.jl +++ b/test/downstream/unwrapping.jl @@ -11,17 +11,17 @@ integrator = init(ode, Tsit5()) using OrdinaryDiffEq, ForwardDiff, Measurements x = 1.0 ± 0.0 -f = (du,u,p,t)-> du .= u -tspan = (0.0,1.0) +f = (du, u, p, t) -> du .= u +tspan = (0.0, 1.0) prob = ODEProblem(f, [x], tspan) # Should not error during problem construction but should be unwrapped integ = init(prob, Tsit5(), dt = 0.1) @test integ.f.f === f -tspan = (ForwardDiff.Dual(0.0,(0.01)),ForwardDiff.Dual(1.0,(0.01))) +tspan = (ForwardDiff.Dual(0.0, (0.01)), ForwardDiff.Dual(1.0, (0.01))) prob = ODEProblem(f, [x], tspan) # Should not error during problem construction but should be unwrapped integ = init(prob, Tsit5(), dt = 0.1) -@test integ.f.f === f \ No newline at end of file +@test integ.f.f === f