diff --git a/src/systems/alias_elimination.jl b/src/systems/alias_elimination.jl index 288cc9dd56..3d68a4abce 100644 --- a/src/systems/alias_elimination.jl +++ b/src/systems/alias_elimination.jl @@ -35,6 +35,7 @@ end alias_elimination(sys) = alias_elimination!(TearingState(sys; quick_cancel = true)) function alias_elimination!(state::TearingState) sys = state.sys + complete!(state.structure) ag, mm, updated_diff_vars = alias_eliminate_graph!(state) ag === nothing && return sys @@ -52,8 +53,20 @@ function alias_elimination!(state::TearingState) end subs = Dict() + # If we encounter y = -D(x), then we need to expand the derivative when + # D(y) appears in the equation, so that D(-D(x)) becomes -D(D(x)). + to_expand = Int[] + diff_to_var = invview(var_to_diff) for (v, (coeff, alias)) in pairs(ag) subs[fullvars[v]] = iszero(coeff) ? 0 : coeff * fullvars[alias] + if coeff == -1 + # if `alias` is like -D(x) + diff_to_var[alias] === nothing && continue + # if `v` is like y, and D(y) also exists + (dv = var_to_diff[v]) === nothing && continue + # all equations that contains D(y) needs to be expanded. + append!(to_expand, đť‘‘neighbors(graph, dv)) + end end dels = Int[] @@ -72,11 +85,29 @@ function alias_elimination!(state::TearingState) end end deleteat!(eqs, sort!(dels)) + old_to_new = Vector{Int}(undef, length(var_to_diff)) + idx = 0 + cursor = 1 + ndels = length(dels) + for (i, e) in enumerate(old_to_new) + if cursor <= ndels && i == dels[cursor] + cursor += 1 + old_to_new[i] = -1 + continue + end + idx += 1 + old_to_new[i] = idx + end for (ieq, eq) in enumerate(eqs) eqs[ieq] = substitute(eq, subs) end + for old_ieq in to_expand + ieq = old_to_new[old_ieq] + eqs[ieq] = expand_derivatives(eqs[ieq]) + end + newstates = [] diff_to_var = invview(var_to_diff) for j in eachindex(fullvars) diff --git a/test/odesystem.jl b/test/odesystem.jl index d12d716c27..16e76990e7 100644 --- a/test/odesystem.jl +++ b/test/odesystem.jl @@ -849,3 +849,17 @@ let D(sys.x) ~ sys.v] @test isequal(full_equations(sys_simp), true_eqs) end + +let + @variables t + @variables x(t) = 1 + @variables y(t) = 1 + @parameters pp = -1 + der = Differential(t) + @named sys4 = ODESystem([der(x) ~ -y; der(y) ~ 1 - y + x], t) + as = alias_elimination(sys4) + @test length(equations(as)) == 1 + @test isequal(equations(as)[1].lhs, -der(der(x))) + # TODO: maybe do not emit x_t + @test_nowarn sys4s = structural_simplify(sys4) +end