Skip to content

Commit

Permalink
Merge pull request #1900 from SciML/myb/morealias
Browse files Browse the repository at this point in the history
Add one missing case in the complete alias generation
  • Loading branch information
YingboMa committed Oct 22, 2022
2 parents 00590bf + d8e0ac9 commit 2905e79
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 0 deletions.
10 changes: 10 additions & 0 deletions src/systems/alias_elimination.jl
Original file line number Diff line number Diff line change
Expand Up @@ -708,17 +708,27 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
prev_r = -1
for _ in 1:10_000 # just to make sure that we don't stuck in an infinite loop
reach₌ = Pair{Int, Int}[]
# `r` is aliased to its equality aliases
r === nothing || for n in neighbors(eqg, r)
(n == r || is_diff_edge(r, n)) && continue
c = get_weight(eqg, r, n)
push!(reach₌, c => n)
end
# `r` is aliased to its previous differentiation level's aliases'
# derivative
if (n = length(diff_aliases)) >= 1
as = diff_aliases[n]
for (c, a) in as
(da = var_to_diff[a]) === nothing && continue
da === r && continue
push!(reach₌, c => da)
# `r` is aliased to its previous differentiation level's
# aliases' derivative's equality aliases
r === nothing || for n in neighbors(eqg, da)
(n == da || n == prev_r || is_diff_edge(prev_r, n)) && continue
c′ = get_weight(eqg, da, n)
push!(reach₌, c * c′ => n)
end
end
end
if r === nothing
Expand Down
22 changes: 22 additions & 0 deletions test/reduction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -302,3 +302,25 @@ ss = alias_elimination(sys)
@test length(equations(ss)) == length(states(ss)) == 1
ss = structural_simplify(sys)
@test length(equations(ss)) == length(states(ss)) == 2

@variables t
vars = @variables x(t) y(t) k(t) z(t) zₜ(t) ddx(t)
D = Differential(t)
eqs = [D(D(x)) ~ ddx
ddx ~ y
D(x) ~ z
D(z) ~ zₜ
D(zₜ) ~ sin(t)
D(x) ~ D(k)
D(D(D(x))) ~ sin(t)]
@named sys = ODESystem(eqs, t, vars, [])
state = TearingState(sys);
ag, mm, complete_ag, complete_mm = ModelingToolkit.alias_eliminate_graph!(state)
fullvars = state.fullvars
aliases = []
for (v, (c, a)) in complete_ag
push!(aliases, fullvars[v] => c == 0 ? 0 : c * fullvars[a])
end
ref_aliases = [D(k) => D(x); z => D(x); D(z) => D(D(x)); zₜ => D(D(x)); ddx => D(D(x));
y => D(D(x)); D(zₜ) => D(D(D(x)))]
@test Set(aliases) == Set(ref_aliases)

0 comments on commit 2905e79

Please sign in to comment.