diff --git a/src/solve.jl b/src/solve.jl index 4ce046fd1..8f5c2b824 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -968,7 +968,11 @@ function promote_f(f::F, ::Val{specialize}, u0, p, t) where {F, specialize} RecursiveArrayTools.recursive_unitless_eltype(u0) === eltype(u0) && one(t) === oneunit(t) && Tricks.static_hasmethod(ArrayInterfaceCore.promote_eltype, - Tuple{Type{typeof(u0)}, Type{eltype(u0)}})) || + Tuple{Type{typeof(u0)}, Type{dualgen(eltype(u0))}}) && + Tricks.static_hasmethod(promote_rule, + Tuple{Type{eltype(u0)}, Type{dualgen(eltype(u0))}}) && + Tricks.static_hasmethod(promote_rule, + Tuple{Type{eltype(u0)}, Type{typeof(t)}})) || (specialize === SciMLBase.FunctionWrapperSpecialize && !(f.f isa FunctionWrappersWrappers.FunctionWrappersWrapper))) return wrapfun_iip(f.f, (u0, u0, p, t)) diff --git a/test/downstream/Project.toml b/test/downstream/Project.toml index 0083c3a78..f73ad0206 100644 --- a/test/downstream/Project.toml +++ b/test/downstream/Project.toml @@ -4,6 +4,7 @@ DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7" ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78" MultiScaleArrays = "f9640e96-87f6-5992-9c3b-0743c6a49ffa" ODEProblemLibrary = "fdc4e326-1af4-4b90-96e7-779fcce2daa5" diff --git a/test/downstream/unwrapping.jl b/test/downstream/unwrapping.jl index 45f08eabc..2fc871f2b 100644 --- a/test/downstream/unwrapping.jl +++ b/test/downstream/unwrapping.jl @@ -8,3 +8,20 @@ integrator = init(ode, Tsit5()) ode = ODEProblem(my_f!, [1.0], (0.0, 1.0)) integrator = init(ode, Tsit5()) @test SciMLBase.unwrapped_f(integrator.f.f) === my_f! + +using OrdinaryDiffEq, ForwardDiff, Measurements +x = 1.0 ± 0.0 +f = (du, u, p, t) -> du .= u +tspan = (0.0, 1.0) +prob = ODEProblem(f, [x], tspan) + +# Should not error during problem construction but should be unwrapped +integ = init(prob, Tsit5(), dt = 0.1) +@test integ.f.f === f + +tspan = (ForwardDiff.Dual(0.0, (0.01)), ForwardDiff.Dual(1.0, (0.01))) +prob = ODEProblem(f, [x], tspan) + +# Should not error during problem construction but should be unwrapped +integ = init(prob, Tsit5(), dt = 0.1) +@test integ.f.f === f