diff --git a/src/systems/alias_elimination.jl b/src/systems/alias_elimination.jl index 4142f14240..7d9d8d96c2 100644 --- a/src/systems/alias_elimination.jl +++ b/src/systems/alias_elimination.jl @@ -708,17 +708,27 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL) prev_r = -1 for _ in 1:10_000 # just to make sure that we don't stuck in an infinite loop reach₌ = Pair{Int, Int}[] + # `r` is aliased to its equality aliases r === nothing || for n in neighbors(eqg, r) (n == r || is_diff_edge(r, n)) && continue c = get_weight(eqg, r, n) push!(reach₌, c => n) end + # `r` is aliased to its previous differentiation level's aliases' + # derivative if (n = length(diff_aliases)) >= 1 as = diff_aliases[n] for (c, a) in as (da = var_to_diff[a]) === nothing && continue da === r && continue push!(reach₌, c => da) + # `r` is aliased to its previous differentiation level's + # aliases' derivative's equality aliases + r === nothing || for n in neighbors(eqg, da) + (n == da || n == prev_r || is_diff_edge(prev_r, n)) && continue + c′ = get_weight(eqg, da, n) + push!(reach₌, c * c′ => n) + end end end if r === nothing diff --git a/test/reduction.jl b/test/reduction.jl index e534c7aa30..918254a540 100644 --- a/test/reduction.jl +++ b/test/reduction.jl @@ -302,3 +302,25 @@ ss = alias_elimination(sys) @test length(equations(ss)) == length(states(ss)) == 1 ss = structural_simplify(sys) @test length(equations(ss)) == length(states(ss)) == 2 + +@variables t +vars = @variables x(t) y(t) k(t) z(t) zₜ(t) ddx(t) +D = Differential(t) +eqs = [D(D(x)) ~ ddx + ddx ~ y + D(x) ~ z + D(z) ~ zₜ + D(zₜ) ~ sin(t) + D(x) ~ D(k) + D(D(D(x))) ~ sin(t)] +@named sys = ODESystem(eqs, t, vars, []) +state = TearingState(sys); +ag, mm, complete_ag, complete_mm = ModelingToolkit.alias_eliminate_graph!(state) +fullvars = state.fullvars +aliases = [] +for (v, (c, a)) in complete_ag + push!(aliases, fullvars[v] => c == 0 ? 0 : c * fullvars[a]) +end +ref_aliases = [D(k) => D(x); z => D(x); D(z) => D(D(x)); zₜ => D(D(x)); ddx => D(D(x)); + y => D(D(x)); D(zₜ) => D(D(D(x)))] +@test Set(aliases) == Set(ref_aliases)