diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl index 388ad9c3de..b13c62605a 100644 --- a/src/systems/diffeqs/abstractodesystem.jl +++ b/src/systems/diffeqs/abstractodesystem.jl @@ -418,3 +418,5 @@ end function SteadyStateProblemExpr(sys::AbstractODESystem, args...; kwargs...) SteadyStateProblemExpr{true}(sys, args...; kwargs...) end + +isdiffeq(eq) = eq.lhs isa Term && operation(eq.lhs) isa Differential diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl index e09e312ed3..1772abe759 100644 --- a/src/systems/diffeqs/odesystem.jl +++ b/src/systems/diffeqs/odesystem.jl @@ -119,21 +119,16 @@ function ODESystem(eqs, iv=nothing; kwargs...) end iv === nothing && throw(ArgumentError("Please pass in independent variables.")) for eq in eqs - for var in vars(eq.rhs for eq ∈ eqs) - if isparameter(var) || isparameter(var.op) - isequal(var, iv) || push!(ps, var) - else - push!(allstates, var) - end - end - if !(eq.lhs isa Symbolic) - push!(algeeq, eq) - else - diffvar = first(var_from_nested_derivative(eq.lhs)) + collect_vars!(allstates, ps, eq.lhs, iv) + collect_vars!(allstates, ps, eq.rhs, iv) + if isdiffeq(eq) + diffvar, _ = var_from_nested_derivative(eq.lhs) isequal(iv, iv_from_nested_derivative(eq.lhs)) || throw(ArgumentError("An ODESystem can only have one independent variable.")) diffvar in diffvars && throw(ArgumentError("The differential variable $diffvar is not unique in the system of equations.")) push!(diffvars, diffvar) push!(diffeq, eq) + else + push!(algeeq, eq) end end algevars = setdiff(allstates, diffvars) @@ -141,6 +136,17 @@ function ODESystem(eqs, iv=nothing; kwargs...) return ODESystem(append!(diffeq, algeeq), iv, vcat(collect(diffvars), collect(algevars)), ps; kwargs...) end +function collect_vars!(states, parameters, expr, iv) + for var in vars(expr) + if isparameter(var) || isparameter(var.op) + isequal(var, iv) || push!(parameters, var) + else + push!(states, var) + end + end + return nothing +end + Base.:(==)(sys1::ODESystem, sys2::ODESystem) = _eq_unordered(sys1.eqs, sys2.eqs) && isequal(sys1.iv, sys2.iv) && _eq_unordered(sys1.states, sys2.states) && _eq_unordered(sys1.ps, sys2.ps) diff --git a/src/systems/reduction.jl b/src/systems/reduction.jl index 0fa239c105..f66e22c9a9 100644 --- a/src/systems/reduction.jl +++ b/src/systems/reduction.jl @@ -6,6 +6,8 @@ function flatten(sys::ODESystem) else return ODESystem(equations(sys), independent_variable(sys), + states(sys), + parameters(sys), observed=observed(sys)) end end @@ -27,59 +29,67 @@ function substitute_aliases(diffeqs, dict) lhss(diffeqs) .~ fixpoint_sub.(rhss(diffeqs), (dict,)) end -isvar(s::Sym) = !isparameter(s) -isvar(s::Term) = isvar(s.op) +# Note that we reduce parameters, too +# i.e. `2param = 3` will be reduced away +isvar(s::Sym) = true +isvar(s::Term) = isvar(operation(s)) isvar(s::Any) = false -function filterexpr(f, s) - vs = [] - Rewriters.Prewalk(Rewriters.Chain([@rule((~x::f) => push!(vs, ~x))]))(s) - vs -end +function get_α_x(αx) + if isvar(αx) + return 1, αx + elseif αx isa Term && operation(αx) === (*) + args = arguments(αx) + nums = [] + syms = [] + for arg in args + isvar(arg) ? push!(syms, arg) : push!(nums, arg) + end -function make_lhs_0(eq) - if eq.lhs isa Number && iszero(eq.lhs) - return eq + if length(syms) == 1 + return prod(nums), syms[1] + end else - 0 ~ eq.lhs - eq.rhs + return nothing end end function alias_elimination(sys::ODESystem) eqs = vcat(equations(sys), observed(sys)) - - # make all algebraic equations have 0 on LHS - eqs = map(eqs) do eq - if eq.lhs isa Term && eq.lhs.op isa Differential - eq + subs = Pair[] + diff_vars = filter(!isnothing, map(eqs) do eq + if isdiffeq(eq) + eq.lhs.args[1] + else + nothing + end + end) |> Set + + # only substitute when the variable is algebraic + del = Int[] + for (i, eq) in enumerate(eqs) + isdiffeq(eq) && continue + res_left = get_α_x(eq.lhs) + if !isnothing(res_left) && !(res_left[2] in diff_vars) + # `α x = rhs` => `x = rhs / α` + α, x = res_left + push!(subs, x => _isone(α) ? eq.rhs : eq.rhs / α) + push!(del, i) else - make_lhs_0(eq) + res_right = get_α_x(eq.rhs) + if !isnothing(res_right) && !(res_right[2] in diff_vars) + # `lhs = β y` => `y = lhs / β` + β, y = res_right + push!(subs, y => _isone(β) ? eq.lhs : β * eq.lhs) + push!(del, i) + end end end + deleteat!(eqs, del) - newstates = map(eqs) do eq - if eq.lhs isa Term && eq.lhs.op isa Differential - filterexpr(isvar, eq.lhs) - else - [] - end - end |> Iterators.flatten |> collect |> unique - + eqs′ = substitute_aliases(eqs, Dict(subs)) + alias_vars = first.(subs) - all_vars = map(eqs) do eq - filterexpr(isvar, eq.rhs) - end |> Iterators.flatten |> collect |> unique - - alg_idxs = findall(x->!(x.lhs isa Term) && iszero(x.lhs), eqs) - - eliminate = setdiff(all_vars, newstates) - - outputs = solve_for(eqs[alg_idxs], eliminate) - - diffeqs = eqs[setdiff(1:length(eqs), alg_idxs)] - - diffeqs′ = substitute_aliases(diffeqs, Dict(eliminate .=> outputs)) - - ODESystem(diffeqs′, sys.iv, newstates, parameters(sys), observed=eliminate .~ outputs) + newstates = setdiff(states(sys), alias_vars) + ODESystem(eqs′, sys.iv, newstates, parameters(sys), observed=alias_vars .~ last.(subs)) end - diff --git a/test/odesystem.jl b/test/odesystem.jl index 7fd72f203d..d6e1f76411 100644 --- a/test/odesystem.jl +++ b/test/odesystem.jl @@ -200,3 +200,14 @@ for (prob, atol) in [(prob1, 1e-12), (prob2, 1e-12), (prob3, 1e-12)] sol = solve(prob, Rodas5()) @test all(x->≈(sum(x), 1.0, atol=atol), sol.u) end + +@parameters t σ β +@variables x(t) y(t) z(t) +@derivatives D'~t +eqs = [D(x) ~ σ*(y-x), + D(y) ~ x-β*y, + x + z ~ y] +sys = ODESystem(eqs) +@test all(isequal.(states(sys), [x, y, z])) +@test all(isequal.(parameters(sys), [σ, β])) +@test equations(sys) == eqs diff --git a/test/reduction.jl b/test/reduction.jl index af6d1c8a51..2478d5df20 100644 --- a/test/reduction.jl +++ b/test/reduction.jl @@ -7,53 +7,52 @@ using ModelingToolkit, OrdinaryDiffEq, Test test_equal(a, b) = @test isequal(simplify(a, polynorm=true), simplify(b, polynorm=true)) eqs = [D(x) ~ σ*(y-x), - D(y) ~ x*(ρ-z)-y, - D(z) ~ a*y - β*z, - 0 ~ x - a] + D(y) ~ x*(ρ-z)-y + β, + 0 ~ sin(z) - x + y, + sin(u) ~ x + y, + 2β ~ 2, + x ~ a, + ] -lorenz1 = ODESystem(eqs,t,[x,y,z,a],[σ,ρ,β],name=:lorenz1) +lorenz1 = ODESystem(eqs,t,[u,x,y,z,a],[σ,ρ,β],name=:lorenz1) lorenz1_aliased = alias_elimination(lorenz1) -@test length(equations(lorenz1_aliased)) == 3 -@test length(states(lorenz1_aliased)) == 3 - -eqs = [D(x) ~ σ*(y-x), - D(y) ~ x*(ρ-z)-y, - D(z) ~ x*y - β*z] - -@test lorenz1_aliased == ODESystem(eqs,t,[x,y,z],[σ,ρ,β],observed=[a ~ x],name=:lorenz1) +reduced_eqs = [ + D(x) ~ σ * (y - x), + D(y) ~ x*(ρ-z)-y + 1, + 0 ~ sin(z) - x + y, + sin(u) ~ x + y, + ] +test_equal.(equations(lorenz1_aliased), reduced_eqs) +test_equal.(states(lorenz1_aliased), [u, x, y, z]) +test_equal.(observed(lorenz1_aliased), [ + β ~ 1, + a ~ x, +]) # Multi-System Reduction -eqs1 = [D(x) ~ σ*(y-x) + F, - D(y) ~ x*(ρ-z)-u, - D(z) ~ x*y - β*z] - -aliases = [u ~ x + y - z] - -lorenz1 = ODESystem(eqs1,pins=[F],observed=aliases,name=:lorenz1) - -eqs2 = [D(x) ~ F, - D(y) ~ x*(ρ-z)-x, - D(z) ~ x*y - β*z] +eqs1 = [ + D(x) ~ σ*(y-x) + F, + D(y) ~ x*(ρ-z)-u, + D(z) ~ x*y - β*z, + u ~ x + y - z, + ] -aliases2 = [u ~ x - y - z] +lorenz1 = ODESystem(eqs1,pins=[F],name=:lorenz1) -lorenz2 = ODESystem(eqs2,pins=[F],observed=aliases2,name=:lorenz2) +eqs2 = [ + D(x) ~ F, + D(y) ~ x*(ρ-z)-x, + D(z) ~ x*y - β*z, + u ~ x - y - z + ] -connections = [lorenz1.F ~ lorenz2.u, - lorenz2.F ~ lorenz1.u] +lorenz2 = ODESystem(eqs2,pins=[F],name=:lorenz2) -connected = ODESystem([0 ~ a + lorenz1.x - lorenz2.y],t,[a],[],observed=connections,systems=[lorenz1,lorenz2]) - -# Reduced Unflattened System -#= - -connections2 = [lorenz1.F ~ lorenz2.u, - lorenz2.F ~ lorenz1.u, - a ~ -lorenz1.x + lorenz2.y] -connected = ODESystem(Equation[],t,[],[],observed=connections2,systems=[lorenz1,lorenz2]) -=# +connected = ODESystem([lorenz2.y ~ a + lorenz1.x, + lorenz1.F ~ lorenz2.u, + lorenz2.F ~ lorenz1.u],t,[a],[],systems=[lorenz1,lorenz2]) # Reduced Flattened System @@ -62,6 +61,7 @@ flattened_system = ModelingToolkit.flatten(connected) aliased_flattened_system = alias_elimination(flattened_system) @test isequal(states(aliased_flattened_system), [ + a lorenz1.x lorenz1.y lorenz1.z @@ -80,22 +80,24 @@ aliased_flattened_system = alias_elimination(flattened_system) lorenz2.β ]) |> isempty -test_equal.(equations(aliased_flattened_system), [ - D(lorenz1.x) ~ lorenz1.σ*(lorenz1.y-lorenz1.x) + lorenz2.x - lorenz2.y - lorenz2.z, - D(lorenz1.y) ~ lorenz1.x*(lorenz1.ρ-lorenz1.z)-(lorenz1.x + lorenz1.y - lorenz1.z), - D(lorenz1.z) ~ lorenz1.x*lorenz1.y - lorenz1.β*lorenz1.z, - D(lorenz2.x) ~ lorenz1.x + lorenz1.y - lorenz1.z, - D(lorenz2.y) ~ lorenz2.x*(lorenz2.ρ-lorenz2.z)-lorenz2.x, - D(lorenz2.z) ~ lorenz2.x*lorenz2.y - lorenz2.β*lorenz2.z]) - -test_equal.(observed(aliased_flattened_system), [ - lorenz1.F ~ lorenz2.x + -1 * (lorenz2.y + lorenz2.z), - lorenz1.u ~ lorenz1.x + lorenz1.y + -1 * lorenz1.z, - lorenz2.F ~ lorenz1.x + lorenz1.y + -1 * lorenz1.z, - a ~ lorenz2.y + -1 * lorenz1.x, - lorenz2.u ~ lorenz2.x + -1 * (lorenz2.y + lorenz2.z), -]) - +reduced_eqs = [ + lorenz2.y ~ a + lorenz1.x, # irreducible by alias elimination + D(lorenz1.x) ~ lorenz1.σ*(lorenz1.y-lorenz1.x) + lorenz2.x - lorenz2.y - lorenz2.z, + D(lorenz1.y) ~ lorenz1.x*(lorenz1.ρ-lorenz1.z)-(lorenz1.x + lorenz1.y - lorenz1.z), + D(lorenz1.z) ~ lorenz1.x*lorenz1.y - lorenz1.β*lorenz1.z, + D(lorenz2.x) ~ lorenz1.x + lorenz1.y - lorenz1.z, + D(lorenz2.y) ~ lorenz2.x*(lorenz2.ρ-lorenz2.z)-lorenz2.x, + D(lorenz2.z) ~ lorenz2.x*lorenz2.y - lorenz2.β*lorenz2.z + ] +test_equal.(equations(aliased_flattened_system), reduced_eqs) + +observed_eqs = [ + lorenz1.F ~ lorenz2.u, + lorenz2.F ~ lorenz1.u, + lorenz1.u ~ lorenz1.x + lorenz1.y - lorenz1.z, + lorenz2.u ~ lorenz2.x - lorenz2.y - lorenz2.z, + ] +test_equal.(observed(aliased_flattened_system), observed_eqs) # issue #578