Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/ModelingToolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ using Latexify, Unitful, ArrayInterface
using MacroTools
using UnPack: @unpack
using DiffEqJump
using DataStructures: OrderedDict, OrderedSet
using DataStructures
using SpecialFunctions, NaNMath
using RuntimeGeneratedFunctions
using Base.Threads
Expand Down
1 change: 1 addition & 0 deletions src/equations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ struct Equation
end
Base.:(==)(a::Equation, b::Equation) = all(isequal.((a.lhs, a.rhs), (b.lhs, b.rhs)))
Base.hash(a::Equation, salt::UInt) = hash(a.lhs, hash(a.rhs, salt))
Base.show(io::IO, eq::Equation) = print(io, eq.lhs, " ~ ", eq.rhs)

SymbolicUtils.simplify(x::Equation; kw...) = simplify(x.lhs; kw...) ~ simplify(x.rhs; kw...)

Expand Down
125 changes: 114 additions & 11 deletions src/systems/reduction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ end

function alias_elimination(sys::ODESystem)
eqs = vcat(equations(sys), observed(sys))
neweqs = Equation[]; sizehint!(neweqs, length(eqs))
subs = Pair[]
diff_vars = filter(!isnothing, map(eqs) do eq
if isdiffeq(eq)
Expand All @@ -65,31 +66,133 @@ function alias_elimination(sys::ODESystem)
end
end) |> Set

# only substitute when the variable is algebraic
del = Int[]
deps = Set()
for (i, eq) in enumerate(eqs)
isdiffeq(eq) && continue
# only substitute when the variable is algebraic
if isdiffeq(eq)
push!(neweqs, eq)
continue
end

maybe_alias = isalias = false
res_left = get_α_x(eq.lhs)
if !isnothing(res_left) && !(res_left[2] in diff_vars)
# `α x = rhs` => `x = rhs / α`
α, x = res_left
push!(subs, x => _isone(α) ? eq.rhs : eq.rhs / α)
push!(del, i)
sub = x => _isone(α) ? eq.rhs : eq.rhs / α
maybe_alias = true
else
res_right = get_α_x(eq.rhs)
if !isnothing(res_right) && !(res_right[2] in diff_vars)
# `lhs = β y` => `y = lhs / β`
β, y = res_right
push!(subs, y => _isone(β) ? eq.lhs : β * eq.lhs)
push!(del, i)
sub = y => _isone(β) ? eq.lhs : β * eq.lhs
maybe_alias = true
end
end

if maybe_alias
l, r = sub
# alias equations shouldn't introduce cycles
if !(l in deps) && isempty(intersect(deps, vars(r)))
push!(deps, l)
push!(subs, sub)
isalias = true
end
end

if !isalias
neweq = _iszero(eq.lhs) ? eq : 0 ~ eq.rhs - eq.lhs
push!(neweqs, neweq)
end
end
deleteat!(eqs, del)

eqs′ = substitute_aliases(eqs, Dict(subs))
eqs′ = substitute_aliases(neweqs, Dict(subs))

alias_vars = first.(subs)
sys_states = states(sys)
alias_eqs = alias_vars .~ last.(subs)
#alias_eqs = topsort_equations(alias_eqs, sys_states)

newstates = setdiff(sys_states, alias_vars)
ODESystem(eqs′, sys.iv, newstates, parameters(sys), observed=alias_eqs)
end

"""
$(SIGNATURES)

Use Kahn's algorithm to topologically sort observed equations.

Example:
```julia
julia> @variables t x(t) y(t) z(t) k(t)
(t, x(t), y(t), z(t), k(t))

julia> eqs = [
x ~ y + z
z ~ 2
y ~ 2z + k
];

julia> ModelingToolkit.topsort_equations(eqs, [x, y, z, k])
3-element Vector{Equation}:
Equation(z(t), 2)
Equation(y(t), k(t) + 2z(t))
Equation(x(t), y(t) + z(t))
```
"""
function topsort_equations(eqs, states; check=true)
graph, assigns = observed2graph(eqs, states)
neqs = length(eqs)
degrees = zeros(Int, neqs)

for 𝑠eq in 1:length(eqs); var = assigns[𝑠eq]
for 𝑑eq in 𝑑neighbors(graph, var)
# 𝑠eq => 𝑑eq
degrees[𝑑eq] += 1
end
end

q = Queue{Int}(neqs)
for (i, d) in enumerate(degrees)
d == 0 && enqueue!(q, i)
end

idx = 0
ordered_eqs = similar(eqs, 0); sizehint!(ordered_eqs, neqs)
while !isempty(q)
𝑠eq = dequeue!(q)
idx+=1
push!(ordered_eqs, eqs[𝑠eq])
var = assigns[𝑠eq]
for 𝑑eq in 𝑑neighbors(graph, var)
degree = degrees[𝑑eq] = degrees[𝑑eq] - 1
degree == 0 && enqueue!(q, 𝑑eq)
end
end

(check && idx != neqs) && throw(ArgumentError("The equations have at least one cycle."))

return ordered_eqs
end

function observed2graph(eqs, states)
graph = BipartiteGraph(length(eqs), length(states))
v2j = Dict(states .=> 1:length(states))

# `assigns: eq -> var`, `eq` defines `var`
assigns = similar(eqs, Int)

for (i, eq) in enumerate(eqs)
lhs_j = get(v2j, eq.lhs, nothing)
lhs_j === nothing && throw(ArgumentError("The lhs $(eq.lhs) of $eq, doesn't appear in states."))
assigns[i] = lhs_j
vs = vars(eq.rhs)
for v in vs
j = get(v2j, v, nothing)
j !== nothing && add_edge!(graph, i, j)
end
end

newstates = setdiff(states(sys), alias_vars)
ODESystem(eqs′, sys.iv, newstates, parameters(sys), observed=alias_vars .~ last.(subs))
return graph, assigns
end
49 changes: 39 additions & 10 deletions test/reduction.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,27 @@
using ModelingToolkit, OrdinaryDiffEq, Test
using ModelingToolkit: topsort_equations

@variables t x(t) y(t) z(t) k(t)
eqs = [
x ~ y + z
z ~ 2
y ~ 2z + k
]

sorted_eq = topsort_equations(eqs, [x, y, z, k])

ref_eq = [
z ~ 2
y ~ 2z + k
x ~ y + z
]
@test ref_eq == sorted_eq

@test_throws ArgumentError topsort_equations([
x ~ y + z
z ~ 2
y ~ 2z + x
], [x, y, z, k])

@parameters t σ ρ β
@variables x(t) y(t) z(t) a(t) u(t) F(t)
Expand All @@ -21,7 +44,7 @@ reduced_eqs = [
D(x) ~ σ * (y - x),
D(y) ~ x*(ρ-z)-y + 1,
0 ~ sin(z) - x + y,
sin(u) ~ x + y,
0 ~ x + y - sin(u),
]
test_equal.(equations(lorenz1_aliased), reduced_eqs)
test_equal.(states(lorenz1_aliased), [u, x, y, z])
Expand Down Expand Up @@ -81,7 +104,7 @@ aliased_flattened_system = alias_elimination(flattened_system)
]) |> isempty

reduced_eqs = [
lorenz2.y ~ a + lorenz1.x, # irreducible by alias elimination
0 ~ a + lorenz1.x - lorenz2.y, # irreducible by alias elimination
D(lorenz1.x) ~ lorenz1.σ*(lorenz1.y-lorenz1.x) + lorenz2.x - lorenz2.y - lorenz2.z,
D(lorenz1.y) ~ lorenz1.x*(lorenz1.ρ-lorenz1.z)-(lorenz1.x + lorenz1.y - lorenz1.z),
D(lorenz1.z) ~ lorenz1.x*lorenz1.y - lorenz1.β*lorenz1.z,
Expand Down Expand Up @@ -115,22 +138,28 @@ let
test_equal.(asys.observed, [y ~ x])
end

# issue #716
# issue #724 and #716
let
@parameters t
D = Differential(t)
@variables x(t), u(t), y(t)
@parameters a, b, c, d
ol = ODESystem([D(x) ~ a * x + b * u, y ~ c * x], t, name=:ol)
ol = ODESystem([D(x) ~ a * x + b * u; y ~ c * x + d * u], t, pins=[u], name=:ol)
@variables u_c(t), y_c(t)
@parameters k_P
pc = ODESystem(Equation[], t, pins=[y_c], observed = [u_c ~ k_P * y_c], name=:pc)
pc = ODESystem(Equation[u_c ~ k_P * y_c], t, pins=[y_c], name=:pc)
connections = [
ol.u ~ pc.u_c,
y_c ~ ol.y
]
ol.u ~ pc.u_c,
pc.y_c ~ ol.y
]
connected = ODESystem(connections, t, systems=[ol, pc])

@test equations(connected) isa Vector{Equation}
@test_nowarn flatten(connected)
sys = flatten(connected)
reduced_sys = alias_elimination(sys)
ref_eqs = [
D(ol.x) ~ ol.a*ol.x + ol.b*pc.u_c
0 ~ ol.c*ol.x + ol.d*pc.u_c - ol.y
0 ~ pc.k_P*ol.y - pc.u_c
]
@test ref_eqs == equations(reduced_sys)
end