Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions src/systems/alias_elimination.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ end
alias_elimination(sys) = alias_elimination!(TearingState(sys; quick_cancel = true))
function alias_elimination!(state::TearingState)
sys = state.sys
complete!(state.structure)
ag, mm, updated_diff_vars = alias_eliminate_graph!(state)
ag === nothing && return sys

Expand All @@ -52,8 +53,20 @@ function alias_elimination!(state::TearingState)
end

subs = Dict()
# If we encounter y = -D(x), then we need to expand the derivative when
# D(y) appears in the equation, so that D(-D(x)) becomes -D(D(x)).
to_expand = Int[]
diff_to_var = invview(var_to_diff)
for (v, (coeff, alias)) in pairs(ag)
subs[fullvars[v]] = iszero(coeff) ? 0 : coeff * fullvars[alias]
if coeff == -1
# if `alias` is like -D(x)
diff_to_var[alias] === nothing && continue
# if `v` is like y, and D(y) also exists
(dv = var_to_diff[v]) === nothing && continue
# all equations that contains D(y) needs to be expanded.
append!(to_expand, 𝑑neighbors(graph, dv))
end
end

dels = Int[]
Expand All @@ -72,11 +85,29 @@ function alias_elimination!(state::TearingState)
end
end
deleteat!(eqs, sort!(dels))
old_to_new = Vector{Int}(undef, length(var_to_diff))
idx = 0
cursor = 1
ndels = length(dels)
for (i, e) in enumerate(old_to_new)
if cursor <= ndels && i == dels[cursor]
cursor += 1
old_to_new[i] = -1
continue
end
idx += 1
old_to_new[i] = idx
end

for (ieq, eq) in enumerate(eqs)
eqs[ieq] = substitute(eq, subs)
end

for old_ieq in to_expand
ieq = old_to_new[old_ieq]
eqs[ieq] = expand_derivatives(eqs[ieq])
end

newstates = []
diff_to_var = invview(var_to_diff)
for j in eachindex(fullvars)
Expand Down
14 changes: 14 additions & 0 deletions test/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -849,3 +849,17 @@ let
D(sys.x) ~ sys.v]
@test isequal(full_equations(sys_simp), true_eqs)
end

let
@variables t
@variables x(t) = 1
@variables y(t) = 1
@parameters pp = -1
der = Differential(t)
@named sys4 = ODESystem([der(x) ~ -y; der(y) ~ 1 - y + x], t)
as = alias_elimination(sys4)
@test length(equations(as)) == 1
@test isequal(equations(as)[1].lhs, -der(der(x)))
# TODO: maybe do not emit x_t
@test_nowarn sys4s = structural_simplify(sys4)
end