Skip to content

Commit

Permalink
Merge pull request #1898 from SciML/myb/tspan
Browse files Browse the repository at this point in the history
Add tspan to *System
  • Loading branch information
YingboMa committed Oct 21, 2022
2 parents 3149f77 + d0a4fa5 commit 00590bf
Show file tree
Hide file tree
Showing 8 changed files with 36 additions and 20 deletions.
1 change: 1 addition & 0 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ for prop in [:eqs
:iv
:states
:ps
:tspan
:var_to_name
:ctrls
:defaults
Expand Down
3 changes: 2 additions & 1 deletion src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -680,7 +680,8 @@ function DiffEqBase.ODEProblem{false}(sys::AbstractODESystem, args...; kwargs...
ODEProblem{false, SciMLBase.FullSpecialize}(sys, args...; kwargs...)
end

function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map, tspan,
function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map = [],
tspan = get_tspan(sys),
parammap = DiffEqBase.NullParameters();
callback = nothing,
check_length = true,
Expand Down
11 changes: 7 additions & 4 deletions src/systems/diffeqs/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ eqs = [D(x) ~ σ*(y-x),
D(y) ~ x*(ρ-z)-y,
D(z) ~ x*y - β*z]
@named de = ODESystem(eqs,t,[x,y,z],[σ,ρ,β])
@named de = ODESystem(eqs,t,[x,y,z],[σ,ρ,β],tspan=(0, 1000.0))
```
"""
struct ODESystem <: AbstractODESystem
Expand All @@ -41,6 +41,8 @@ struct ODESystem <: AbstractODESystem
states::Vector
"""Parameter variables. Must not contain the independent variable."""
ps::Vector
"""Time span."""
tspan::Union{NTuple{2, Any}, Nothing}
"""Array variables."""
var_to_name::Any
"""Control parameters (some subset of `ps`)."""
Expand Down Expand Up @@ -125,7 +127,7 @@ struct ODESystem <: AbstractODESystem
"""
complete::Bool

function ODESystem(tag, deqs, iv, dvs, ps, var_to_name, ctrls, observed, tgrad,
function ODESystem(tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad,
jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults,
torn_matching, connector_type, preface, cevents,
devents, metadata = nothing, tearing_state = nothing,
Expand All @@ -140,7 +142,7 @@ struct ODESystem <: AbstractODESystem
if checks == true || (checks & CheckUnits) > 0
all_dimensionless([dvs; ps; iv]) || check_units(deqs)
end
new(tag, deqs, iv, dvs, ps, var_to_name, ctrls, observed, tgrad, jac,
new(tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad, jac,
ctrl_jac, Wfact, Wfact_t, name, systems, defaults, torn_matching,
connector_type, preface, cevents, devents, metadata, tearing_state,
substitutions, complete)
Expand All @@ -151,6 +153,7 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
controls = Num[],
observed = Equation[],
systems = ODESystem[],
tspan = nothing,
name = nothing,
default_u0 = Dict(),
default_p = Dict(),
Expand Down Expand Up @@ -195,7 +198,7 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
cont_callbacks = SymbolicContinuousCallbacks(continuous_events)
disc_callbacks = SymbolicDiscreteCallbacks(discrete_events)
ODESystem(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)),
deqs, iv′, dvs′, ps′, var_to_name, ctrl′, observed, tgrad, jac,
deqs, iv′, dvs′, ps′, tspan, var_to_name, ctrl′, observed, tgrad, jac,
ctrl_jac, Wfact, Wfact_t, name, systems, defaults, nothing,
connector_type, preface, cont_callbacks, disc_callbacks,
metadata, checks = checks)
Expand Down
14 changes: 9 additions & 5 deletions src/systems/diffeqs/sdesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ noiseeqs = [0.1*x,
0.1*y,
0.1*z]
@named de = SDESystem(eqs,noiseeqs,t,[x,y,z],[σ,ρ,β])
@named de = SDESystem(eqs,noiseeqs,t,[x,y,z],[σ,ρ,β]; tspan = (0, 1000.0))
```
"""
struct SDESystem <: AbstractODESystem
Expand All @@ -42,6 +42,8 @@ struct SDESystem <: AbstractODESystem
states::Vector
"""Parameter variables. Must not contain the independent variable."""
ps::Vector
"""Time span."""
tspan::Union{NTuple{2, Any}, Nothing}
"""Array variables."""
var_to_name::Any
"""Control parameters (some subset of `ps`)."""
Expand Down Expand Up @@ -110,7 +112,8 @@ struct SDESystem <: AbstractODESystem
"""
complete::Bool

function SDESystem(tag, deqs, neqs, iv, dvs, ps, var_to_name, ctrls, observed, tgrad,
function SDESystem(tag, deqs, neqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed,
tgrad,
jac,
ctrl_jac, Wfact, Wfact_t, name, systems, defaults, connector_type,
cevents, devents, metadata = nothing, complete = false;
Expand All @@ -124,7 +127,7 @@ struct SDESystem <: AbstractODESystem
if checks == true || (checks & CheckUnits) > 0
all_dimensionless([dvs; ps; iv]) || check_units(deqs, neqs)
end
new(tag, deqs, neqs, iv, dvs, ps, var_to_name, ctrls, observed, tgrad, jac,
new(tag, deqs, neqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad, jac,
ctrl_jac,
Wfact, Wfact_t, name, systems, defaults, connector_type, cevents, devents,
metadata, complete)
Expand All @@ -135,6 +138,7 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs, iv, dvs, ps;
controls = Num[],
observed = Num[],
systems = SDESystem[],
tspan = nothing,
default_u0 = Dict(),
default_p = Dict(),
defaults = _merge(Dict(default_u0), Dict(default_p)),
Expand Down Expand Up @@ -177,7 +181,7 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs, iv, dvs, ps;
disc_callbacks = SymbolicDiscreteCallbacks(discrete_events)

SDESystem(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)),
deqs, neqs, iv′, dvs′, ps′, var_to_name, ctrl′, observed, tgrad, jac,
deqs, neqs, iv′, dvs′, ps′, tspan, var_to_name, ctrl′, observed, tgrad, jac,
ctrl_jac, Wfact, Wfact_t, name, systems, defaults, connector_type,
cont_callbacks, disc_callbacks, metadata; checks = checks)
end
Expand Down Expand Up @@ -531,7 +535,7 @@ function SDEFunctionExpr(sys::SDESystem, args...; kwargs...)
SDEFunctionExpr{true}(sys, args...; kwargs...)
end

function DiffEqBase.SDEProblem{iip}(sys::SDESystem, u0map, tspan,
function DiffEqBase.SDEProblem{iip}(sys::SDESystem, u0map = [], tspan = get_tspan(sys),
parammap = DiffEqBase.NullParameters();
sparsenoise = nothing, check_length = true,
callback = nothing, kwargs...) where {iip}
Expand Down
15 changes: 10 additions & 5 deletions src/systems/discrete_system/discrete_system.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ eqs = [D(x) ~ σ*(y-x),
D(y) ~ x*(ρ-z)-y,
D(z) ~ x*y - β*z]
@named de = DiscreteSystem(eqs,t,[x,y,z],[σ,ρ,β]) # or
@named de = DiscreteSystem(eqs,t,[x,y,z],[σ,ρ,β]; tspan = (0, 1000.0)) # or
@named de = DiscreteSystem(eqs)
```
"""
Expand All @@ -37,6 +37,8 @@ struct DiscreteSystem <: AbstractTimeDependentSystem
states::Vector
"""Parameter variables. Must not contain the independent variable."""
ps::Vector
"""Time span."""
tspan::Union{NTuple{2, Any}, Nothing}
"""Array variables."""
var_to_name::Any
"""Control parameters (some subset of `ps`)."""
Expand Down Expand Up @@ -81,7 +83,8 @@ struct DiscreteSystem <: AbstractTimeDependentSystem
"""
complete::Bool

function DiscreteSystem(tag, discreteEqs, iv, dvs, ps, var_to_name, ctrls, observed,
function DiscreteSystem(tag, discreteEqs, iv, dvs, ps, tspan, var_to_name, ctrls,
observed,
name,
systems, defaults, preface, connector_type,
metadata = nothing,
Expand All @@ -94,7 +97,8 @@ struct DiscreteSystem <: AbstractTimeDependentSystem
if checks == true || (checks & CheckUnits) > 0
all_dimensionless([dvs; ps; iv; ctrls]) || check_units(discreteEqs)
end
new(tag, discreteEqs, iv, dvs, ps, var_to_name, ctrls, observed, name, systems,
new(tag, discreteEqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, name,
systems,
defaults,
preface, connector_type, metadata, tearing_state, substitutions, complete)
end
Expand All @@ -109,6 +113,7 @@ function DiscreteSystem(eqs::AbstractVector{<:Equation}, iv, dvs, ps;
controls = Num[],
observed = Num[],
systems = DiscreteSystem[],
tspan = nothing,
name = nothing,
default_u0 = Dict(),
default_p = Dict(),
Expand Down Expand Up @@ -142,7 +147,7 @@ function DiscreteSystem(eqs::AbstractVector{<:Equation}, iv, dvs, ps;
throw(ArgumentError("System names must be unique."))
end
DiscreteSystem(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)),
eqs, iv′, dvs′, ps′, var_to_name, ctrl′, observed, name, systems,
eqs, iv′, dvs′, ps′, tspan, var_to_name, ctrl′, observed, name, systems,
defaults, preface, connector_type, metadata, kwargs...)
end

Expand Down Expand Up @@ -192,7 +197,7 @@ end
Generates an DiscreteProblem from an DiscreteSystem.
"""
function SciMLBase.DiscreteProblem(sys::DiscreteSystem, u0map, tspan,
function SciMLBase.DiscreteProblem(sys::DiscreteSystem, u0map = [], tspan = get_tspan(sys),
parammap = SciMLBase.NullParameters();
eval_module = @__MODULE__,
eval_expression = true,
Expand Down
4 changes: 2 additions & 2 deletions test/discretesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,10 @@ eqs2 = [D(S) ~ S - infection2,
D(I) ~ I + infection2 - recovery2,
D(R) ~ R + recovery2]

@named sys = DiscreteSystem(eqs2; controls = [β, γ])
@named sys = DiscreteSystem(eqs2; controls = [β, γ], tspan)
@test ModelingToolkit.defaults(sys) != Dict()

prob_map2 = DiscreteProblem(sys, [], tspan)
prob_map2 = DiscreteProblem(sys)
sol_map2 = solve(prob_map, FunctionMap());

@test sol_map.u == sol_map2.u
Expand Down
5 changes: 3 additions & 2 deletions test/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -374,10 +374,11 @@ eqs = [D(D(x)) ~ -b / M * D(x) - k / M * x]
ps = [M, b, k]
default_u0 = [D(x) => 0.0, x => 10.0]
default_p = [M => 1.0, b => 1.0, k => 1.0]
@named sys = ODESystem(eqs, t, [x], ps, defaults = [default_u0; default_p])
@named sys = ODESystem(eqs, t, [x], ps; defaults = [default_u0; default_p], tspan)
sys = ode_order_lowering(sys)
prob = ODEProblem(sys, [], tspan)
prob = ODEProblem(sys)
sol = solve(prob, Tsit5())
@test sol.t[end] == tspan[end]
@test sum(abs, sol[end]) < 1

# check_eqs_u0 kwarg test
Expand Down
3 changes: 2 additions & 1 deletion test/sdesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ noiseeqs = [0.1 * x,
@named sys = ODESystem(eqs, t, [x, y, z], [σ, ρ, β])
@test SDESystem(sys, noiseeqs, name = :foo) isa SDESystem

@named de = SDESystem(eqs, noiseeqs, t, [x, y, z], [σ, ρ, β])
@named de = SDESystem(eqs, noiseeqs, t, [x, y, z], [σ, ρ, β], tspan = (0.0, 10.0))
f = eval(generate_diffusion_function(de)[1])
@test f(ones(3), rand(3), nothing) == 0.1ones(3)

Expand All @@ -36,6 +36,7 @@ solexpr = solve(eval(probexpr), SRIW1(), seed = 1)

# Test no error
@test_nowarn SDEProblem(de, nothing, (0, 10.0))
@test SDEProblem(de, nothing).tspan == (0.0, 10.0)

noiseeqs_nd = [0.01*x 0.01*x*y 0.02*x*z
σ 0.01*y 0.02*x*z
Expand Down

0 comments on commit 00590bf

Please sign in to comment.