From de23f7521e0fe24a06381971876cb9f2d9472c25 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Fri, 21 Oct 2022 14:22:11 -0400 Subject: [PATCH 1/4] Add tspan --- src/systems/abstractsystem.jl | 1 + src/systems/diffeqs/abstractodesystem.jl | 3 ++- src/systems/diffeqs/odesystem.jl | 9 ++++++--- src/systems/diffeqs/sdesystem.jl | 12 ++++++++---- src/systems/discrete_system/discrete_system.jl | 13 +++++++++---- 5 files changed, 26 insertions(+), 12 deletions(-) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 9e6cfed3b2..483b51640c 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -182,6 +182,7 @@ for prop in [:eqs :iv :states :ps + :tspan :var_to_name :ctrls :defaults diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl index a3300e53f3..860685a2d3 100644 --- a/src/systems/diffeqs/abstractodesystem.jl +++ b/src/systems/diffeqs/abstractodesystem.jl @@ -680,7 +680,8 @@ function DiffEqBase.ODEProblem{false}(sys::AbstractODESystem, args...; kwargs... ODEProblem{false, SciMLBase.FullSpecialize}(sys, args...; kwargs...) end -function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map, tspan, +function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map = [], + tspan = get_tspan(sys), parammap = DiffEqBase.NullParameters(); callback = nothing, check_length = true, diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl index 3192249d9c..02129fe679 100644 --- a/src/systems/diffeqs/odesystem.jl +++ b/src/systems/diffeqs/odesystem.jl @@ -41,6 +41,8 @@ struct ODESystem <: AbstractODESystem states::Vector """Parameter variables. Must not contain the independent variable.""" ps::Vector + """Time span.""" + tspan::Union{NTuple{2, Any}, Nothing} """Array variables.""" var_to_name::Any """Control parameters (some subset of `ps`).""" @@ -125,7 +127,7 @@ struct ODESystem <: AbstractODESystem """ complete::Bool - function ODESystem(tag, deqs, iv, dvs, ps, var_to_name, ctrls, observed, tgrad, + function ODESystem(tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, torn_matching, connector_type, preface, cevents, devents, metadata = nothing, tearing_state = nothing, @@ -140,7 +142,7 @@ struct ODESystem <: AbstractODESystem if checks == true || (checks & CheckUnits) > 0 all_dimensionless([dvs; ps; iv]) || check_units(deqs) end - new(tag, deqs, iv, dvs, ps, var_to_name, ctrls, observed, tgrad, jac, + new(tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, torn_matching, connector_type, preface, cevents, devents, metadata, tearing_state, substitutions, complete) @@ -151,6 +153,7 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps; controls = Num[], observed = Equation[], systems = ODESystem[], + tspan = nothing, name = nothing, default_u0 = Dict(), default_p = Dict(), @@ -195,7 +198,7 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps; cont_callbacks = SymbolicContinuousCallbacks(continuous_events) disc_callbacks = SymbolicDiscreteCallbacks(discrete_events) ODESystem(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)), - deqs, iv′, dvs′, ps′, var_to_name, ctrl′, observed, tgrad, jac, + deqs, iv′, dvs′, ps′, tspan, var_to_name, ctrl′, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, nothing, connector_type, preface, cont_callbacks, disc_callbacks, metadata, checks = checks) diff --git a/src/systems/diffeqs/sdesystem.jl b/src/systems/diffeqs/sdesystem.jl index d1bab96cd3..292cfc865e 100644 --- a/src/systems/diffeqs/sdesystem.jl +++ b/src/systems/diffeqs/sdesystem.jl @@ -42,6 +42,8 @@ struct SDESystem <: AbstractODESystem states::Vector """Parameter variables. Must not contain the independent variable.""" ps::Vector + """Time span.""" + tspan::Union{NTuple{2, Any}, Nothing} """Array variables.""" var_to_name::Any """Control parameters (some subset of `ps`).""" @@ -110,7 +112,8 @@ struct SDESystem <: AbstractODESystem """ complete::Bool - function SDESystem(tag, deqs, neqs, iv, dvs, ps, var_to_name, ctrls, observed, tgrad, + function SDESystem(tag, deqs, neqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, + tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, connector_type, cevents, devents, metadata = nothing, complete = false; @@ -124,7 +127,7 @@ struct SDESystem <: AbstractODESystem if checks == true || (checks & CheckUnits) > 0 all_dimensionless([dvs; ps; iv]) || check_units(deqs, neqs) end - new(tag, deqs, neqs, iv, dvs, ps, var_to_name, ctrls, observed, tgrad, jac, + new(tag, deqs, neqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, connector_type, cevents, devents, metadata, complete) @@ -135,6 +138,7 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs, iv, dvs, ps; controls = Num[], observed = Num[], systems = SDESystem[], + tspan = nothing, default_u0 = Dict(), default_p = Dict(), defaults = _merge(Dict(default_u0), Dict(default_p)), @@ -177,7 +181,7 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs, iv, dvs, ps; disc_callbacks = SymbolicDiscreteCallbacks(discrete_events) SDESystem(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)), - deqs, neqs, iv′, dvs′, ps′, var_to_name, ctrl′, observed, tgrad, jac, + deqs, neqs, iv′, dvs′, ps′, tspan, var_to_name, ctrl′, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, connector_type, cont_callbacks, disc_callbacks, metadata; checks = checks) end @@ -531,7 +535,7 @@ function SDEFunctionExpr(sys::SDESystem, args...; kwargs...) SDEFunctionExpr{true}(sys, args...; kwargs...) end -function DiffEqBase.SDEProblem{iip}(sys::SDESystem, u0map, tspan, +function DiffEqBase.SDEProblem{iip}(sys::SDESystem, u0map = [], tspan = get_tspan(sys), parammap = DiffEqBase.NullParameters(); sparsenoise = nothing, check_length = true, callback = nothing, kwargs...) where {iip} diff --git a/src/systems/discrete_system/discrete_system.jl b/src/systems/discrete_system/discrete_system.jl index 1d0b3fea87..9ef9b4469d 100644 --- a/src/systems/discrete_system/discrete_system.jl +++ b/src/systems/discrete_system/discrete_system.jl @@ -37,6 +37,8 @@ struct DiscreteSystem <: AbstractTimeDependentSystem states::Vector """Parameter variables. Must not contain the independent variable.""" ps::Vector + """Time span.""" + tspan::Union{NTuple{2, Any}, Nothing} """Array variables.""" var_to_name::Any """Control parameters (some subset of `ps`).""" @@ -81,7 +83,8 @@ struct DiscreteSystem <: AbstractTimeDependentSystem """ complete::Bool - function DiscreteSystem(tag, discreteEqs, iv, dvs, ps, var_to_name, ctrls, observed, + function DiscreteSystem(tag, discreteEqs, iv, dvs, ps, tspan, var_to_name, ctrls, + observed, name, systems, defaults, preface, connector_type, metadata = nothing, @@ -94,7 +97,8 @@ struct DiscreteSystem <: AbstractTimeDependentSystem if checks == true || (checks & CheckUnits) > 0 all_dimensionless([dvs; ps; iv; ctrls]) || check_units(discreteEqs) end - new(tag, discreteEqs, iv, dvs, ps, var_to_name, ctrls, observed, name, systems, + new(tag, discreteEqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, name, + systems, defaults, preface, connector_type, metadata, tearing_state, substitutions, complete) end @@ -109,6 +113,7 @@ function DiscreteSystem(eqs::AbstractVector{<:Equation}, iv, dvs, ps; controls = Num[], observed = Num[], systems = DiscreteSystem[], + tspan = nothing, name = nothing, default_u0 = Dict(), default_p = Dict(), @@ -142,7 +147,7 @@ function DiscreteSystem(eqs::AbstractVector{<:Equation}, iv, dvs, ps; throw(ArgumentError("System names must be unique.")) end DiscreteSystem(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)), - eqs, iv′, dvs′, ps′, var_to_name, ctrl′, observed, name, systems, + eqs, iv′, dvs′, ps′, tspan, var_to_name, ctrl′, observed, name, systems, defaults, preface, connector_type, metadata, kwargs...) end @@ -192,7 +197,7 @@ end Generates an DiscreteProblem from an DiscreteSystem. """ -function SciMLBase.DiscreteProblem(sys::DiscreteSystem, u0map, tspan, +function SciMLBase.DiscreteProblem(sys::DiscreteSystem, u0map = [], tspan = get_tspan(sys), parammap = SciMLBase.NullParameters(); eval_module = @__MODULE__, eval_expression = true, From 444d609543052d105ee5a334d65df61b6645e751 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Fri, 21 Oct 2022 14:30:08 -0400 Subject: [PATCH 2/4] Add tspan tests --- test/discretesystem.jl | 4 ++-- test/odesystem.jl | 5 +++-- test/sdesystem.jl | 1 + 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/test/discretesystem.jl b/test/discretesystem.jl index e1e5b11dbd..ffe5ecfde9 100644 --- a/test/discretesystem.jl +++ b/test/discretesystem.jl @@ -52,10 +52,10 @@ eqs2 = [D(S) ~ S - infection2, D(I) ~ I + infection2 - recovery2, D(R) ~ R + recovery2] -@named sys = DiscreteSystem(eqs2; controls = [β, γ]) +@named sys = DiscreteSystem(eqs2; controls = [β, γ], tspan) @test ModelingToolkit.defaults(sys) != Dict() -prob_map2 = DiscreteProblem(sys, [], tspan) +prob_map2 = DiscreteProblem(sys) sol_map2 = solve(prob_map, FunctionMap()); @test sol_map.u == sol_map2.u diff --git a/test/odesystem.jl b/test/odesystem.jl index b8ef48dc97..9a5b51fbd1 100644 --- a/test/odesystem.jl +++ b/test/odesystem.jl @@ -374,10 +374,11 @@ eqs = [D(D(x)) ~ -b / M * D(x) - k / M * x] ps = [M, b, k] default_u0 = [D(x) => 0.0, x => 10.0] default_p = [M => 1.0, b => 1.0, k => 1.0] -@named sys = ODESystem(eqs, t, [x], ps, defaults = [default_u0; default_p]) +@named sys = ODESystem(eqs, t, [x], ps; defaults = [default_u0; default_p], tspan) sys = ode_order_lowering(sys) -prob = ODEProblem(sys, [], tspan) +prob = ODEProblem(sys) sol = solve(prob, Tsit5()) +@test sol.t[end] == tspan[end] @test sum(abs, sol[end]) < 1 # check_eqs_u0 kwarg test diff --git a/test/sdesystem.jl b/test/sdesystem.jl index 6fa6570758..557e41a004 100644 --- a/test/sdesystem.jl +++ b/test/sdesystem.jl @@ -36,6 +36,7 @@ solexpr = solve(eval(probexpr), SRIW1(), seed = 1) # Test no error @test_nowarn SDEProblem(de, nothing, (0, 10.0)) +@test_nowarn SDEProblem(de) noiseeqs_nd = [0.01*x 0.01*x*y 0.02*x*z σ 0.01*y 0.02*x*z From 9f81cd184329562bb2b82c27245517eefc46801b Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Fri, 21 Oct 2022 14:30:18 -0400 Subject: [PATCH 3/4] Docs --- src/systems/diffeqs/odesystem.jl | 2 +- src/systems/diffeqs/sdesystem.jl | 2 +- src/systems/discrete_system/discrete_system.jl | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl index 02129fe679..835c8b82e9 100644 --- a/src/systems/diffeqs/odesystem.jl +++ b/src/systems/diffeqs/odesystem.jl @@ -19,7 +19,7 @@ eqs = [D(x) ~ σ*(y-x), D(y) ~ x*(ρ-z)-y, D(z) ~ x*y - β*z] -@named de = ODESystem(eqs,t,[x,y,z],[σ,ρ,β]) +@named de = ODESystem(eqs,t,[x,y,z],[σ,ρ,β],tspan=(0, 1000.0)) ``` """ struct ODESystem <: AbstractODESystem diff --git a/src/systems/diffeqs/sdesystem.jl b/src/systems/diffeqs/sdesystem.jl index 292cfc865e..214e2821d6 100644 --- a/src/systems/diffeqs/sdesystem.jl +++ b/src/systems/diffeqs/sdesystem.jl @@ -23,7 +23,7 @@ noiseeqs = [0.1*x, 0.1*y, 0.1*z] -@named de = SDESystem(eqs,noiseeqs,t,[x,y,z],[σ,ρ,β]) +@named de = SDESystem(eqs,noiseeqs,t,[x,y,z],[σ,ρ,β]; tspan = (0, 1000.0)) ``` """ struct SDESystem <: AbstractODESystem diff --git a/src/systems/discrete_system/discrete_system.jl b/src/systems/discrete_system/discrete_system.jl index 9ef9b4469d..d1856ee1dc 100644 --- a/src/systems/discrete_system/discrete_system.jl +++ b/src/systems/discrete_system/discrete_system.jl @@ -19,7 +19,7 @@ eqs = [D(x) ~ σ*(y-x), D(y) ~ x*(ρ-z)-y, D(z) ~ x*y - β*z] -@named de = DiscreteSystem(eqs,t,[x,y,z],[σ,ρ,β]) # or +@named de = DiscreteSystem(eqs,t,[x,y,z],[σ,ρ,β]; tspan = (0, 1000.0)) # or @named de = DiscreteSystem(eqs) ``` """ From d0a4fa530c23200472c7b5f386d82f094c8dac9d Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Fri, 21 Oct 2022 15:28:53 -0400 Subject: [PATCH 4/4] Fix tests --- test/sdesystem.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/sdesystem.jl b/test/sdesystem.jl index 557e41a004..454930531b 100644 --- a/test/sdesystem.jl +++ b/test/sdesystem.jl @@ -20,7 +20,7 @@ noiseeqs = [0.1 * x, @named sys = ODESystem(eqs, t, [x, y, z], [σ, ρ, β]) @test SDESystem(sys, noiseeqs, name = :foo) isa SDESystem -@named de = SDESystem(eqs, noiseeqs, t, [x, y, z], [σ, ρ, β]) +@named de = SDESystem(eqs, noiseeqs, t, [x, y, z], [σ, ρ, β], tspan = (0.0, 10.0)) f = eval(generate_diffusion_function(de)[1]) @test f(ones(3), rand(3), nothing) == 0.1ones(3) @@ -36,7 +36,7 @@ solexpr = solve(eval(probexpr), SRIW1(), seed = 1) # Test no error @test_nowarn SDEProblem(de, nothing, (0, 10.0)) -@test_nowarn SDEProblem(de) +@test SDEProblem(de, nothing).tspan == (0.0, 10.0) noiseeqs_nd = [0.01*x 0.01*x*y 0.02*x*z σ 0.01*y 0.02*x*z