From 85fb196fd934948b1fa189c72c8a03d70b5a5db5 Mon Sep 17 00:00:00 2001 From: Shashi Gowda Date: Mon, 21 Dec 2020 06:58:54 -0500 Subject: [PATCH 1/7] alias elimination simplified --- src/systems/reduction.jl | 70 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 70 insertions(+) diff --git a/src/systems/reduction.jl b/src/systems/reduction.jl index 0fa239c105..9010563f36 100644 --- a/src/systems/reduction.jl +++ b/src/systems/reduction.jl @@ -83,3 +83,73 @@ function alias_elimination(sys::ODESystem) ODESystem(diffeqs′, sys.iv, newstates, parameters(sys), observed=eliminate .~ outputs) end +function get_α_x(αx) + if isvar(αx) + return αx, 1 + elseif αx isa Term && operation(αx) === (*) + args = arguments(αx) + nums = filter(!isvar, args) + syms = filter(isvar, args) + + if length(syms) == 1 + return syms[1], prod(nums) + end + else + return nothing + end +end + +function alias_elimination2(sys) + eqs = vcat(equations(sys), observed(sys)) + + subs = Pair[] + # Case 1: Right hand side is a constant + ii = findall(eqs) do eq + (eq.lhs isa Sym || (eq.lhs isa Term && !(eq.lhs.op isa Differential))) && !(eq.rhs isa Symbolic) + end + for eq in eqs[ii] + substitution_dict[eq.lhs] = eq.rhs + push!(subs, eq.lhs => eq.rhs) + end + deleteat!(eqs, ii) # remove them + + # Case 2: One side is a differentiated var, the other is an algebraic var + # substitute the algebraic var with the diff var + diff_vars = findall(eqs) do eq + if eq.lhs isa Term && eq.lhs.op isa Differential + eq.lhs.args[1] + else + nothing + end + end + + for eq in eqs + res_left = get_α_x(eq.lhs) + if !isnothing(res) + res_right = get_α_x(eq.rhs) + β, y = res + if y in diff_vars && !(x in diff_vars) + multiple = β / α + push!(subs, x => isone(multiple) ? y : multiple * y) + elseif x in diff_vars && !(y in diff_vars) + multiple = α / β + push!(subs, y => isone(multiple) ? y : multiple * y) + end + end + end + + # Case 3: Explicit substitutions + for eq in eqs + res_left = get_α_x(eq.lhs) + if !isnothing(res) + res_right = get_α_x(eq.rhs) + β, y = res + multiple = β / α + push!(subs, x => isone(multiple) ? x : multiple * x) + end + end + + diffeqs = filter(eq -> eq.lhs isa Term && eq.lhs.op isa Differential, eqs) + diffeqs′ = substitute_aliases(diffeqs, Dict(subs)) + ODESystem(diffeqs′, sys.iv, newstates, parameters(sys), observed=first.(subs) .~ last.(subs)) +end From 62ecf956183419a0381fb90b4a8ff5a873d9ef4e Mon Sep 17 00:00:00 2001 From: Shashi Gowda Date: Mon, 21 Dec 2020 12:31:32 -0500 Subject: [PATCH 2/7] fixes and first few tests --- src/systems/reduction.jl | 77 ++++++++++++++++++++++++---------------- test/reduction.jl | 17 +++++---- 2 files changed, 57 insertions(+), 37 deletions(-) diff --git a/src/systems/reduction.jl b/src/systems/reduction.jl index 9010563f36..35ddef360a 100644 --- a/src/systems/reduction.jl +++ b/src/systems/reduction.jl @@ -27,9 +27,9 @@ function substitute_aliases(diffeqs, dict) lhss(diffeqs) .~ fixpoint_sub.(rhss(diffeqs), (dict,)) end -isvar(s::Sym) = !isparameter(s) -isvar(s::Term) = isvar(s.op) -isvar(s::Any) = false +isvar(s::Sym; param=false) = param ? true : !isparameter(s) +isvar(s::Term; param=false) = isvar(s.op; param=param) +isvar(s::Any;param=false) = false function filterexpr(f, s) vs = [] @@ -84,15 +84,15 @@ function alias_elimination(sys::ODESystem) end function get_α_x(αx) - if isvar(αx) - return αx, 1 + if isvar(αx, param=true) + return 1, αx elseif αx isa Term && operation(αx) === (*) args = arguments(αx) nums = filter(!isvar, args) syms = filter(isvar, args) if length(syms) == 1 - return syms[1], prod(nums) + return prod(nums), syms[1] end else return nothing @@ -105,51 +105,68 @@ function alias_elimination2(sys) subs = Pair[] # Case 1: Right hand side is a constant ii = findall(eqs) do eq - (eq.lhs isa Sym || (eq.lhs isa Term && !(eq.lhs.op isa Differential))) && !(eq.rhs isa Symbolic) + !(eq.rhs isa Symbolic) end for eq in eqs[ii] - substitution_dict[eq.lhs] = eq.rhs - push!(subs, eq.lhs => eq.rhs) + α,x = get_α_x(eq.lhs) + push!(subs, x => isone(α) ? eq.rhs : eq.rhs / α) end deleteat!(eqs, ii) # remove them # Case 2: One side is a differentiated var, the other is an algebraic var # substitute the algebraic var with the diff var - diff_vars = findall(eqs) do eq - if eq.lhs isa Term && eq.lhs.op isa Differential - eq.lhs.args[1] - else - nothing - end - end + diff_vars = filter(!isnothing, map(eqs) do eq + if eq.lhs isa Term && eq.lhs.op isa Differential + eq.lhs.args[1] + else + nothing + end + end) |> Set - for eq in eqs + del = Int[] + for (i, eq) in enumerate(eqs) res_left = get_α_x(eq.lhs) - if !isnothing(res) + if !isnothing(res_left) + α, x = res_left res_right = get_α_x(eq.rhs) - β, y = res - if y in diff_vars && !(x in diff_vars) - multiple = β / α - push!(subs, x => isone(multiple) ? y : multiple * y) - elseif x in diff_vars && !(y in diff_vars) - multiple = α / β - push!(subs, y => isone(multiple) ? y : multiple * y) + if !isnothing(res_right) + β, y = res_right + if y in diff_vars && !(x in diff_vars) + multiple = β / α + push!(subs, x => isone(multiple) ? y : multiple * y) + push!(del, i) + elseif x in diff_vars && !(y in diff_vars) + multiple = α / β + push!(subs, y => isone(multiple) ? x : multiple * x) + push!(del, i) + end end end end + deleteat!(eqs, del) # Case 3: Explicit substitutions - for eq in eqs + del = Int[] + for (i, eq) in enumerate(eqs) res_left = get_α_x(eq.lhs) - if !isnothing(res) + if !isnothing(res_left) + α, x = res_left res_right = get_α_x(eq.rhs) - β, y = res - multiple = β / α - push!(subs, x => isone(multiple) ? x : multiple * x) + if !isnothing(res_right) + β, y = res_right + multiple = β / α + push!(subs, x => _isone(multiple) ? x : multiple * x) + push!(del, i) + end end end + deleteat!(eqs, del) diffeqs = filter(eq -> eq.lhs isa Term && eq.lhs.op isa Differential, eqs) diffeqs′ = substitute_aliases(diffeqs, Dict(subs)) + + newstates = map(diffeqs) do eq + eq.lhs.args[1] + end ODESystem(diffeqs′, sys.iv, newstates, parameters(sys), observed=first.(subs) .~ last.(subs)) end diff --git a/test/reduction.jl b/test/reduction.jl index af6d1c8a51..d0ff577e45 100644 --- a/test/reduction.jl +++ b/test/reduction.jl @@ -1,4 +1,5 @@ using ModelingToolkit, OrdinaryDiffEq, Test +using ModelingToolkit: alias_elimination2 @parameters t σ ρ β @variables x(t) y(t) z(t) a(t) u(t) F(t) @@ -9,19 +10,21 @@ test_equal(a, b) = @test isequal(simplify(a, polynorm=true), simplify(b, polynor eqs = [D(x) ~ σ*(y-x), D(y) ~ x*(ρ-z)-y, D(z) ~ a*y - β*z, - 0 ~ x - a] + β ~ 2, + x ~ a] lorenz1 = ODESystem(eqs,t,[x,y,z,a],[σ,ρ,β],name=:lorenz1) -lorenz1_aliased = alias_elimination(lorenz1) +lorenz1_aliased = alias_elimination2(lorenz1) @test length(equations(lorenz1_aliased)) == 3 @test length(states(lorenz1_aliased)) == 3 eqs = [D(x) ~ σ*(y-x), D(y) ~ x*(ρ-z)-y, - D(z) ~ x*y - β*z] + D(z) ~ x*y - 2*z] -@test lorenz1_aliased == ODESystem(eqs,t,[x,y,z],[σ,ρ,β],observed=[a ~ x],name=:lorenz1) +# TODO: maybe remove β from ps, or maybe don't allow this example on params +@test lorenz1_aliased == ODESystem(eqs,t,[x,y,z],[σ,ρ,β],observed=[β ~ 2, a ~ x],name=:lorenz1) # Multi-System Reduction @@ -44,7 +47,7 @@ lorenz2 = ODESystem(eqs2,pins=[F],observed=aliases2,name=:lorenz2) connections = [lorenz1.F ~ lorenz2.u, lorenz2.F ~ lorenz1.u] -connected = ODESystem([0 ~ a + lorenz1.x - lorenz2.y],t,[a],[],observed=connections,systems=[lorenz1,lorenz2]) +connected = ODESystem([lorenz2.y ~ a + lorenz1.x ],t,[a],[],observed=connections,systems=[lorenz1,lorenz2]) # Reduced Unflattened System #= @@ -59,7 +62,7 @@ connected = ODESystem(Equation[],t,[],[],observed=connections2,systems=[lorenz1, flattened_system = ModelingToolkit.flatten(connected) -aliased_flattened_system = alias_elimination(flattened_system) +aliased_flattened_system = alias_elimination2(flattened_system) @test isequal(states(aliased_flattened_system), [ lorenz1.x @@ -107,7 +110,7 @@ let x ~ y ]; sys = ODESystem(eqs, t, [x], []); - asys = alias_elimination(ModelingToolkit.flatten(sys)) + asys = alias_elimination2(ModelingToolkit.flatten(sys)) test_equal.(asys.eqs, [D(x) ~ 2x]) test_equal.(asys.observed, [y ~ x]) From 6466d8ade57126e2eed16e55d74e0b4434e660e7 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Wed, 23 Dec 2020 20:06:31 -0500 Subject: [PATCH 3/7] Make alias_elimination more powerful --- src/systems/diffeqs/odesystem.jl | 2 ++ src/systems/reduction.jl | 13 ++++++++++--- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl index e09e312ed3..882bc56c48 100644 --- a/src/systems/diffeqs/odesystem.jl +++ b/src/systems/diffeqs/odesystem.jl @@ -149,3 +149,5 @@ Base.:(==)(sys1::ODESystem, sys2::ODESystem) = function rename(sys::ODESystem,name) ODESystem(sys.eqs, sys.iv, sys.states, sys.ps, sys.pins, sys.observed, sys.tgrad, sys.jac, sys.Wfact, sys.Wfact_t, name, sys.systems) end + +isdiffeq(eq) = eq.lhs isa Term && operation(eq.lhs) isa Differential diff --git a/src/systems/reduction.jl b/src/systems/reduction.jl index 35ddef360a..e9ddbacdf9 100644 --- a/src/systems/reduction.jl +++ b/src/systems/reduction.jl @@ -6,6 +6,8 @@ function flatten(sys::ODESystem) else return ODESystem(equations(sys), independent_variable(sys), + states(sys), + parameters(sys), observed=observed(sys)) end end @@ -116,7 +118,7 @@ function alias_elimination2(sys) # Case 2: One side is a differentiated var, the other is an algebraic var # substitute the algebraic var with the diff var diff_vars = filter(!isnothing, map(eqs) do eq - if eq.lhs isa Term && eq.lhs.op isa Differential + if isdiffeq(eq) eq.lhs.args[1] else nothing @@ -148,14 +150,19 @@ function alias_elimination2(sys) # Case 3: Explicit substitutions del = Int[] for (i, eq) in enumerate(eqs) + isdiffeq(eq) && continue res_left = get_α_x(eq.lhs) if !isnothing(res_left) + # `α x = rhs` => `x = rhs / α` α, x = res_left + push!(subs, x => _isone(α) ? eq.rhs : eq.rhs / α) + push!(del, i) + else res_right = get_α_x(eq.rhs) if !isnothing(res_right) + # `lhs = β y` => `y = lhs / β` β, y = res_right - multiple = β / α - push!(subs, x => _isone(multiple) ? x : multiple * x) + push!(subs, y => _isone(β) ? eq.lhs : β * eq.lhs) push!(del, i) end end From 9654bb21c33e6d29ec72d194337353a4d1a021cc Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Wed, 23 Dec 2020 20:09:05 -0500 Subject: [PATCH 4/7] Update alias_elimination tests --- src/systems/reduction.jl | 54 +--------------------------------------- test/reduction.jl | 42 ++++++++++++++++--------------- 2 files changed, 23 insertions(+), 73 deletions(-) diff --git a/src/systems/reduction.jl b/src/systems/reduction.jl index e9ddbacdf9..b81450acfd 100644 --- a/src/systems/reduction.jl +++ b/src/systems/reduction.jl @@ -33,58 +33,6 @@ isvar(s::Sym; param=false) = param ? true : !isparameter(s) isvar(s::Term; param=false) = isvar(s.op; param=param) isvar(s::Any;param=false) = false -function filterexpr(f, s) - vs = [] - Rewriters.Prewalk(Rewriters.Chain([@rule((~x::f) => push!(vs, ~x))]))(s) - vs -end - -function make_lhs_0(eq) - if eq.lhs isa Number && iszero(eq.lhs) - return eq - else - 0 ~ eq.lhs - eq.rhs - end -end - -function alias_elimination(sys::ODESystem) - eqs = vcat(equations(sys), observed(sys)) - - # make all algebraic equations have 0 on LHS - eqs = map(eqs) do eq - if eq.lhs isa Term && eq.lhs.op isa Differential - eq - else - make_lhs_0(eq) - end - end - - newstates = map(eqs) do eq - if eq.lhs isa Term && eq.lhs.op isa Differential - filterexpr(isvar, eq.lhs) - else - [] - end - end |> Iterators.flatten |> collect |> unique - - - all_vars = map(eqs) do eq - filterexpr(isvar, eq.rhs) - end |> Iterators.flatten |> collect |> unique - - alg_idxs = findall(x->!(x.lhs isa Term) && iszero(x.lhs), eqs) - - eliminate = setdiff(all_vars, newstates) - - outputs = solve_for(eqs[alg_idxs], eliminate) - - diffeqs = eqs[setdiff(1:length(eqs), alg_idxs)] - - diffeqs′ = substitute_aliases(diffeqs, Dict(eliminate .=> outputs)) - - ODESystem(diffeqs′, sys.iv, newstates, parameters(sys), observed=eliminate .~ outputs) -end - function get_α_x(αx) if isvar(αx, param=true) return 1, αx @@ -101,7 +49,7 @@ function get_α_x(αx) end end -function alias_elimination2(sys) +function alias_elimination(sys) eqs = vcat(equations(sys), observed(sys)) subs = Pair[] diff --git a/test/reduction.jl b/test/reduction.jl index d0ff577e45..070fb34ef3 100644 --- a/test/reduction.jl +++ b/test/reduction.jl @@ -1,5 +1,4 @@ using ModelingToolkit, OrdinaryDiffEq, Test -using ModelingToolkit: alias_elimination2 @parameters t σ ρ β @variables x(t) y(t) z(t) a(t) u(t) F(t) @@ -15,7 +14,7 @@ eqs = [D(x) ~ σ*(y-x), lorenz1 = ODESystem(eqs,t,[x,y,z,a],[σ,ρ,β],name=:lorenz1) -lorenz1_aliased = alias_elimination2(lorenz1) +lorenz1_aliased = alias_elimination(lorenz1) @test length(equations(lorenz1_aliased)) == 3 @test length(states(lorenz1_aliased)) == 3 @@ -47,7 +46,7 @@ lorenz2 = ODESystem(eqs2,pins=[F],observed=aliases2,name=:lorenz2) connections = [lorenz1.F ~ lorenz2.u, lorenz2.F ~ lorenz1.u] -connected = ODESystem([lorenz2.y ~ a + lorenz1.x ],t,[a],[],observed=connections,systems=[lorenz1,lorenz2]) +connected = ODESystem([lorenz2.y ~ a + lorenz1.x],t,[a],[],observed=connections,systems=[lorenz1,lorenz2]) # Reduced Unflattened System #= @@ -62,7 +61,7 @@ connected = ODESystem(Equation[],t,[],[],observed=connections2,systems=[lorenz1, flattened_system = ModelingToolkit.flatten(connected) -aliased_flattened_system = alias_elimination2(flattened_system) +aliased_flattened_system = alias_elimination(flattened_system) @test isequal(states(aliased_flattened_system), [ lorenz1.x @@ -83,21 +82,24 @@ aliased_flattened_system = alias_elimination2(flattened_system) lorenz2.β ]) |> isempty -test_equal.(equations(aliased_flattened_system), [ - D(lorenz1.x) ~ lorenz1.σ*(lorenz1.y-lorenz1.x) + lorenz2.x - lorenz2.y - lorenz2.z, - D(lorenz1.y) ~ lorenz1.x*(lorenz1.ρ-lorenz1.z)-(lorenz1.x + lorenz1.y - lorenz1.z), - D(lorenz1.z) ~ lorenz1.x*lorenz1.y - lorenz1.β*lorenz1.z, - D(lorenz2.x) ~ lorenz1.x + lorenz1.y - lorenz1.z, - D(lorenz2.y) ~ lorenz2.x*(lorenz2.ρ-lorenz2.z)-lorenz2.x, - D(lorenz2.z) ~ lorenz2.x*lorenz2.y - lorenz2.β*lorenz2.z]) - -test_equal.(observed(aliased_flattened_system), [ - lorenz1.F ~ lorenz2.x + -1 * (lorenz2.y + lorenz2.z), - lorenz1.u ~ lorenz1.x + lorenz1.y + -1 * lorenz1.z, - lorenz2.F ~ lorenz1.x + lorenz1.y + -1 * lorenz1.z, - a ~ lorenz2.y + -1 * lorenz1.x, - lorenz2.u ~ lorenz2.x + -1 * (lorenz2.y + lorenz2.z), -]) +reduced_eqs = [ + D(lorenz1.x) ~ lorenz1.σ*(lorenz1.y-lorenz1.x) + lorenz2.x - (a + lorenz1.x) - lorenz2.z, + D(lorenz1.y) ~ lorenz1.x*(lorenz1.ρ-lorenz1.z)-(lorenz1.x + lorenz1.y - lorenz1.z), + D(lorenz1.z) ~ lorenz1.x*lorenz1.y - lorenz1.β*lorenz1.z, + D(lorenz2.x) ~ lorenz1.x + lorenz1.y - lorenz1.z, + D(lorenz2.y) ~ lorenz2.x*(lorenz2.ρ-lorenz2.z)-lorenz2.x, + D(lorenz2.z) ~ lorenz2.x*(a + lorenz1.x) - lorenz2.β*lorenz2.z + ] +test_equal.(equations(aliased_flattened_system), reduced_eqs) + +observed_eqs = [ + lorenz2.y ~ a + lorenz1.x, + lorenz1.F ~ lorenz2.u, + lorenz2.F ~ lorenz1.u, + lorenz1.u ~ lorenz1.x + lorenz1.y - lorenz1.z, + lorenz2.u ~ lorenz2.x - lorenz2.y - lorenz2.z, + ] +test_equal.(observed(aliased_flattened_system), observed_eqs) # issue #578 @@ -110,7 +112,7 @@ let x ~ y ]; sys = ODESystem(eqs, t, [x], []); - asys = alias_elimination2(ModelingToolkit.flatten(sys)) + asys = alias_elimination(ModelingToolkit.flatten(sys)) test_equal.(asys.eqs, [D(x) ~ 2x]) test_equal.(asys.observed, [y ~ x]) From d4b8d25e5b2ec6865d9554ce3d3ec0987acc902e Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Wed, 23 Dec 2020 21:43:52 -0500 Subject: [PATCH 5/7] Address code review comments --- src/systems/diffeqs/abstractodesystem.jl | 2 ++ src/systems/reduction.jl | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl index 388ad9c3de..b13c62605a 100644 --- a/src/systems/diffeqs/abstractodesystem.jl +++ b/src/systems/diffeqs/abstractodesystem.jl @@ -418,3 +418,5 @@ end function SteadyStateProblemExpr(sys::AbstractODESystem, args...; kwargs...) SteadyStateProblemExpr{true}(sys, args...; kwargs...) end + +isdiffeq(eq) = eq.lhs isa Term && operation(eq.lhs) isa Differential diff --git a/src/systems/reduction.jl b/src/systems/reduction.jl index b81450acfd..adeb361bf8 100644 --- a/src/systems/reduction.jl +++ b/src/systems/reduction.jl @@ -49,7 +49,7 @@ function get_α_x(αx) end end -function alias_elimination(sys) +function alias_elimination(sys::ODESystem) eqs = vcat(equations(sys), observed(sys)) subs = Pair[] From 6671ae654c4546c625598f23e5862acf09567a43 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Wed, 23 Dec 2020 21:51:34 -0500 Subject: [PATCH 6/7] Fix alias elimination --- src/systems/diffeqs/odesystem.jl | 2 - src/systems/reduction.jl | 68 +++++++------------------- test/reduction.jl | 83 +++++++++++++++----------------- 3 files changed, 58 insertions(+), 95 deletions(-) diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl index 882bc56c48..e09e312ed3 100644 --- a/src/systems/diffeqs/odesystem.jl +++ b/src/systems/diffeqs/odesystem.jl @@ -149,5 +149,3 @@ Base.:(==)(sys1::ODESystem, sys2::ODESystem) = function rename(sys::ODESystem,name) ODESystem(sys.eqs, sys.iv, sys.states, sys.ps, sys.pins, sys.observed, sys.tgrad, sys.jac, sys.Wfact, sys.Wfact_t, name, sys.systems) end - -isdiffeq(eq) = eq.lhs isa Term && operation(eq.lhs) isa Differential diff --git a/src/systems/reduction.jl b/src/systems/reduction.jl index adeb361bf8..f66e22c9a9 100644 --- a/src/systems/reduction.jl +++ b/src/systems/reduction.jl @@ -29,17 +29,22 @@ function substitute_aliases(diffeqs, dict) lhss(diffeqs) .~ fixpoint_sub.(rhss(diffeqs), (dict,)) end -isvar(s::Sym; param=false) = param ? true : !isparameter(s) -isvar(s::Term; param=false) = isvar(s.op; param=param) -isvar(s::Any;param=false) = false +# Note that we reduce parameters, too +# i.e. `2param = 3` will be reduced away +isvar(s::Sym) = true +isvar(s::Term) = isvar(operation(s)) +isvar(s::Any) = false function get_α_x(αx) - if isvar(αx, param=true) + if isvar(αx) return 1, αx elseif αx isa Term && operation(αx) === (*) args = arguments(αx) - nums = filter(!isvar, args) - syms = filter(isvar, args) + nums = [] + syms = [] + for arg in args + isvar(arg) ? push!(syms, arg) : push!(nums, arg) + end if length(syms) == 1 return prod(nums), syms[1] @@ -51,20 +56,7 @@ end function alias_elimination(sys::ODESystem) eqs = vcat(equations(sys), observed(sys)) - subs = Pair[] - # Case 1: Right hand side is a constant - ii = findall(eqs) do eq - !(eq.rhs isa Symbolic) - end - for eq in eqs[ii] - α,x = get_α_x(eq.lhs) - push!(subs, x => isone(α) ? eq.rhs : eq.rhs / α) - end - deleteat!(eqs, ii) # remove them - - # Case 2: One side is a differentiated var, the other is an algebraic var - # substitute the algebraic var with the diff var diff_vars = filter(!isnothing, map(eqs) do eq if isdiffeq(eq) eq.lhs.args[1] @@ -73,41 +65,19 @@ function alias_elimination(sys::ODESystem) end end) |> Set - del = Int[] - for (i, eq) in enumerate(eqs) - res_left = get_α_x(eq.lhs) - if !isnothing(res_left) - α, x = res_left - res_right = get_α_x(eq.rhs) - if !isnothing(res_right) - β, y = res_right - if y in diff_vars && !(x in diff_vars) - multiple = β / α - push!(subs, x => isone(multiple) ? y : multiple * y) - push!(del, i) - elseif x in diff_vars && !(y in diff_vars) - multiple = α / β - push!(subs, y => isone(multiple) ? x : multiple * x) - push!(del, i) - end - end - end - end - deleteat!(eqs, del) - - # Case 3: Explicit substitutions + # only substitute when the variable is algebraic del = Int[] for (i, eq) in enumerate(eqs) isdiffeq(eq) && continue res_left = get_α_x(eq.lhs) - if !isnothing(res_left) + if !isnothing(res_left) && !(res_left[2] in diff_vars) # `α x = rhs` => `x = rhs / α` α, x = res_left push!(subs, x => _isone(α) ? eq.rhs : eq.rhs / α) push!(del, i) else res_right = get_α_x(eq.rhs) - if !isnothing(res_right) + if !isnothing(res_right) && !(res_right[2] in diff_vars) # `lhs = β y` => `y = lhs / β` β, y = res_right push!(subs, y => _isone(β) ? eq.lhs : β * eq.lhs) @@ -117,11 +87,9 @@ function alias_elimination(sys::ODESystem) end deleteat!(eqs, del) - diffeqs = filter(eq -> eq.lhs isa Term && eq.lhs.op isa Differential, eqs) - diffeqs′ = substitute_aliases(diffeqs, Dict(subs)) + eqs′ = substitute_aliases(eqs, Dict(subs)) + alias_vars = first.(subs) - newstates = map(diffeqs) do eq - eq.lhs.args[1] - end - ODESystem(diffeqs′, sys.iv, newstates, parameters(sys), observed=first.(subs) .~ last.(subs)) + newstates = setdiff(states(sys), alias_vars) + ODESystem(eqs′, sys.iv, newstates, parameters(sys), observed=alias_vars .~ last.(subs)) end diff --git a/test/reduction.jl b/test/reduction.jl index 070fb34ef3..2478d5df20 100644 --- a/test/reduction.jl +++ b/test/reduction.jl @@ -7,55 +7,52 @@ using ModelingToolkit, OrdinaryDiffEq, Test test_equal(a, b) = @test isequal(simplify(a, polynorm=true), simplify(b, polynorm=true)) eqs = [D(x) ~ σ*(y-x), - D(y) ~ x*(ρ-z)-y, - D(z) ~ a*y - β*z, - β ~ 2, - x ~ a] + D(y) ~ x*(ρ-z)-y + β, + 0 ~ sin(z) - x + y, + sin(u) ~ x + y, + 2β ~ 2, + x ~ a, + ] -lorenz1 = ODESystem(eqs,t,[x,y,z,a],[σ,ρ,β],name=:lorenz1) +lorenz1 = ODESystem(eqs,t,[u,x,y,z,a],[σ,ρ,β],name=:lorenz1) lorenz1_aliased = alias_elimination(lorenz1) -@test length(equations(lorenz1_aliased)) == 3 -@test length(states(lorenz1_aliased)) == 3 - -eqs = [D(x) ~ σ*(y-x), - D(y) ~ x*(ρ-z)-y, - D(z) ~ x*y - 2*z] - -# TODO: maybe remove β from ps, or maybe don't allow this example on params -@test lorenz1_aliased == ODESystem(eqs,t,[x,y,z],[σ,ρ,β],observed=[β ~ 2, a ~ x],name=:lorenz1) +reduced_eqs = [ + D(x) ~ σ * (y - x), + D(y) ~ x*(ρ-z)-y + 1, + 0 ~ sin(z) - x + y, + sin(u) ~ x + y, + ] +test_equal.(equations(lorenz1_aliased), reduced_eqs) +test_equal.(states(lorenz1_aliased), [u, x, y, z]) +test_equal.(observed(lorenz1_aliased), [ + β ~ 1, + a ~ x, +]) # Multi-System Reduction -eqs1 = [D(x) ~ σ*(y-x) + F, - D(y) ~ x*(ρ-z)-u, - D(z) ~ x*y - β*z] - -aliases = [u ~ x + y - z] +eqs1 = [ + D(x) ~ σ*(y-x) + F, + D(y) ~ x*(ρ-z)-u, + D(z) ~ x*y - β*z, + u ~ x + y - z, + ] -lorenz1 = ODESystem(eqs1,pins=[F],observed=aliases,name=:lorenz1) +lorenz1 = ODESystem(eqs1,pins=[F],name=:lorenz1) -eqs2 = [D(x) ~ F, - D(y) ~ x*(ρ-z)-x, - D(z) ~ x*y - β*z] +eqs2 = [ + D(x) ~ F, + D(y) ~ x*(ρ-z)-x, + D(z) ~ x*y - β*z, + u ~ x - y - z + ] -aliases2 = [u ~ x - y - z] +lorenz2 = ODESystem(eqs2,pins=[F],name=:lorenz2) -lorenz2 = ODESystem(eqs2,pins=[F],observed=aliases2,name=:lorenz2) - -connections = [lorenz1.F ~ lorenz2.u, - lorenz2.F ~ lorenz1.u] - -connected = ODESystem([lorenz2.y ~ a + lorenz1.x],t,[a],[],observed=connections,systems=[lorenz1,lorenz2]) - -# Reduced Unflattened System -#= - -connections2 = [lorenz1.F ~ lorenz2.u, - lorenz2.F ~ lorenz1.u, - a ~ -lorenz1.x + lorenz2.y] -connected = ODESystem(Equation[],t,[],[],observed=connections2,systems=[lorenz1,lorenz2]) -=# +connected = ODESystem([lorenz2.y ~ a + lorenz1.x, + lorenz1.F ~ lorenz2.u, + lorenz2.F ~ lorenz1.u],t,[a],[],systems=[lorenz1,lorenz2]) # Reduced Flattened System @@ -64,6 +61,7 @@ flattened_system = ModelingToolkit.flatten(connected) aliased_flattened_system = alias_elimination(flattened_system) @test isequal(states(aliased_flattened_system), [ + a lorenz1.x lorenz1.y lorenz1.z @@ -83,17 +81,17 @@ aliased_flattened_system = alias_elimination(flattened_system) ]) |> isempty reduced_eqs = [ - D(lorenz1.x) ~ lorenz1.σ*(lorenz1.y-lorenz1.x) + lorenz2.x - (a + lorenz1.x) - lorenz2.z, + lorenz2.y ~ a + lorenz1.x, # irreducible by alias elimination + D(lorenz1.x) ~ lorenz1.σ*(lorenz1.y-lorenz1.x) + lorenz2.x - lorenz2.y - lorenz2.z, D(lorenz1.y) ~ lorenz1.x*(lorenz1.ρ-lorenz1.z)-(lorenz1.x + lorenz1.y - lorenz1.z), D(lorenz1.z) ~ lorenz1.x*lorenz1.y - lorenz1.β*lorenz1.z, D(lorenz2.x) ~ lorenz1.x + lorenz1.y - lorenz1.z, D(lorenz2.y) ~ lorenz2.x*(lorenz2.ρ-lorenz2.z)-lorenz2.x, - D(lorenz2.z) ~ lorenz2.x*(a + lorenz1.x) - lorenz2.β*lorenz2.z + D(lorenz2.z) ~ lorenz2.x*lorenz2.y - lorenz2.β*lorenz2.z ] test_equal.(equations(aliased_flattened_system), reduced_eqs) observed_eqs = [ - lorenz2.y ~ a + lorenz1.x, lorenz1.F ~ lorenz2.u, lorenz2.F ~ lorenz1.u, lorenz1.u ~ lorenz1.x + lorenz1.y - lorenz1.z, @@ -101,7 +99,6 @@ observed_eqs = [ ] test_equal.(observed(aliased_flattened_system), observed_eqs) - # issue #578 let From 80d02c62bc6bf678b13cd7d47c919113b22dda49 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Thu, 24 Dec 2020 18:41:54 -0500 Subject: [PATCH 7/7] Allow `x ~ y` in ODESystem --- src/systems/diffeqs/odesystem.jl | 28 +++++++++++++++++----------- test/odesystem.jl | 11 +++++++++++ 2 files changed, 28 insertions(+), 11 deletions(-) diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl index e09e312ed3..1772abe759 100644 --- a/src/systems/diffeqs/odesystem.jl +++ b/src/systems/diffeqs/odesystem.jl @@ -119,21 +119,16 @@ function ODESystem(eqs, iv=nothing; kwargs...) end iv === nothing && throw(ArgumentError("Please pass in independent variables.")) for eq in eqs - for var in vars(eq.rhs for eq ∈ eqs) - if isparameter(var) || isparameter(var.op) - isequal(var, iv) || push!(ps, var) - else - push!(allstates, var) - end - end - if !(eq.lhs isa Symbolic) - push!(algeeq, eq) - else - diffvar = first(var_from_nested_derivative(eq.lhs)) + collect_vars!(allstates, ps, eq.lhs, iv) + collect_vars!(allstates, ps, eq.rhs, iv) + if isdiffeq(eq) + diffvar, _ = var_from_nested_derivative(eq.lhs) isequal(iv, iv_from_nested_derivative(eq.lhs)) || throw(ArgumentError("An ODESystem can only have one independent variable.")) diffvar in diffvars && throw(ArgumentError("The differential variable $diffvar is not unique in the system of equations.")) push!(diffvars, diffvar) push!(diffeq, eq) + else + push!(algeeq, eq) end end algevars = setdiff(allstates, diffvars) @@ -141,6 +136,17 @@ function ODESystem(eqs, iv=nothing; kwargs...) return ODESystem(append!(diffeq, algeeq), iv, vcat(collect(diffvars), collect(algevars)), ps; kwargs...) end +function collect_vars!(states, parameters, expr, iv) + for var in vars(expr) + if isparameter(var) || isparameter(var.op) + isequal(var, iv) || push!(parameters, var) + else + push!(states, var) + end + end + return nothing +end + Base.:(==)(sys1::ODESystem, sys2::ODESystem) = _eq_unordered(sys1.eqs, sys2.eqs) && isequal(sys1.iv, sys2.iv) && _eq_unordered(sys1.states, sys2.states) && _eq_unordered(sys1.ps, sys2.ps) diff --git a/test/odesystem.jl b/test/odesystem.jl index 7fd72f203d..d6e1f76411 100644 --- a/test/odesystem.jl +++ b/test/odesystem.jl @@ -200,3 +200,14 @@ for (prob, atol) in [(prob1, 1e-12), (prob2, 1e-12), (prob3, 1e-12)] sol = solve(prob, Rodas5()) @test all(x->≈(sum(x), 1.0, atol=atol), sol.u) end + +@parameters t σ β +@variables x(t) y(t) z(t) +@derivatives D'~t +eqs = [D(x) ~ σ*(y-x), + D(y) ~ x-β*y, + x + z ~ y] +sys = ODESystem(eqs) +@test all(isequal.(states(sys), [x, y, z])) +@test all(isequal.(parameters(sys), [σ, β])) +@test equations(sys) == eqs