diff --git a/src/bipartite_graph.jl b/src/bipartite_graph.jl index 96f728f1e1..5b9ca0924b 100644 --- a/src/bipartite_graph.jl +++ b/src/bipartite_graph.jl @@ -49,6 +49,10 @@ end function Matching(m::Int) Matching{Unassigned}(Union{Int, Unassigned}[unassigned for _ in 1:m], nothing) end +function Matching{U}(m::Int) where {U} + Matching{Union{Unassigned, U}}(Union{Int, Unassigned, U}[unassigned for _ in 1:m], + nothing) +end Base.size(m::Matching) = Base.size(m.match) Base.getindex(m::Matching, i::Integer) = m.match[i] @@ -65,9 +69,9 @@ function Base.setindex!(m::Matching{U}, v::Union{Integer, U}, i::Integer) where return m.match[i] = v end -function Base.push!(m::Matching{U}, v::Union{Integer, U}) where {U} +function Base.push!(m::Matching, v) push!(m.match, v) - if v !== unassigned && m.inv_match !== nothing + if v isa Integer && m.inv_match !== nothing m.inv_match[v] = length(m.match) end end @@ -299,8 +303,8 @@ vertices, subject to the constraint that vertices for which `srcfilter` or `dstf return `false` may not be matched. """ function maximal_matching(g::BipartiteGraph, srcfilter = vsrc -> true, - dstfilter = vdst -> true) - matching = Matching(ndsts(g)) + dstfilter = vdst -> true, ::Type{U} = Unassigned) where {U} + matching = Matching{U}(ndsts(g)) foreach(Iterators.filter(srcfilter, 𝑠vertices(g))) do vsrc construct_augmenting_path!(matching, g, vsrc, dstfilter) end diff --git a/src/structural_transformation/bipartite_tearing/modia_tearing.jl b/src/structural_transformation/bipartite_tearing/modia_tearing.jl index 0833bfc9ef..e587827b44 100644 --- a/src/structural_transformation/bipartite_tearing/modia_tearing.jl +++ b/src/structural_transformation/bipartite_tearing/modia_tearing.jl @@ -35,10 +35,10 @@ function tear_graph_block_modia!(var_eq_matching, graph, solvable_graph, eqs, va return nothing end -function tear_graph_modia(structure::SystemStructure; varfilter = v -> true, - eqfilter = eq -> true) +function tear_graph_modia(structure::SystemStructure, ::Type{U} = Unassigned; + varfilter = v -> true, eqfilter = eq -> true) where {U} @unpack graph, solvable_graph = structure - var_eq_matching = complete(maximal_matching(graph, eqfilter, varfilter)) + var_eq_matching = complete(maximal_matching(graph, eqfilter, varfilter, U)) var_sccs::Vector{Union{Vector{Int}, Int}} = find_var_sccs(graph, var_eq_matching) for vars in var_sccs diff --git a/src/structural_transformation/partial_state_selection.jl b/src/structural_transformation/partial_state_selection.jl index cf97e27263..60f66dbb0e 100644 --- a/src/structural_transformation/partial_state_selection.jl +++ b/src/structural_transformation/partial_state_selection.jl @@ -146,53 +146,48 @@ function partial_state_selection_graph!(structure::SystemStructure, var_eq_match var_eq_matching end -function dummy_derivative_graph!(state::TransformationState, jac = nothing) +function dummy_derivative_graph!(state::TransformationState, jac = nothing; kwargs...) + state.structure.solvable_graph === nothing && find_solvables!(state; kwargs...) var_eq_matching = complete(pantelides!(state)) complete!(state.structure) dummy_derivative_graph!(state.structure, var_eq_matching, jac) end -function dummy_derivative_graph!(structure::SystemStructure, var_eq_matching, jac) - @unpack eq_to_diff, var_to_diff, graph = structure - diff_to_eq = invview(eq_to_diff) - diff_to_var = invview(var_to_diff) - invgraph = invview(graph) - - neqs = nsrcs(graph) - eqlevel = zeros(Int, neqs) +function compute_diff_level(diff_to_x) + nxs = length(diff_to_x) + xlevel = zeros(Int, nxs) maxlevel = 0 - for i in 1:neqs + for i in 1:nxs level = 0 - eq = i - while diff_to_eq[eq] !== nothing - eq = diff_to_eq[eq] + x = i + while diff_to_x[x] !== nothing + x = diff_to_x[x] level += 1 end maxlevel = max(maxlevel, level) - eqlevel[i] = level + xlevel[i] = level end + return xlevel, maxlevel +end - nvars = ndsts(graph) - varlevel = zeros(Int, nvars) - for i in 1:nvars - level = 0 - var = i - while diff_to_var[var] !== nothing - var = diff_to_var[var] - level += 1 - end - maxlevel = max(maxlevel, level) - varlevel[i] = level - end +function dummy_derivative_graph!(structure::SystemStructure, var_eq_matching, jac) + @unpack eq_to_diff, var_to_diff, graph = structure + diff_to_eq = invview(eq_to_diff) + diff_to_var = invview(var_to_diff) + invgraph = invview(graph) + + eqlevel, _ = compute_diff_level(diff_to_eq) + varlevel, _ = compute_diff_level(diff_to_var) var_sccs = find_var_sccs(graph, var_eq_matching) - eqcolor = falses(neqs) + eqcolor = falses(nsrcs(graph)) dummy_derivatives = Int[] col_order = Int[] + nvars = ndsts(graph) for vars in var_sccs eqs = [var_eq_matching[var] for var in vars if var_eq_matching[var] !== unassigned] isempty(eqs) && continue - maxlevel = maximum(map(x -> eqlevel[x], eqs)) + maxlevel = maximum(Base.Fix1(getindex, eqlevel), eqs) iszero(maxlevel) && continue rank_matching = Matching(nvars) @@ -220,8 +215,10 @@ function dummy_derivative_graph!(structure::SystemStructure, var_eq_matching, ja else rank = 0 for var in vars + # We need `invgraph` here because we are matching from + # variables to equations. pathfound = construct_augmenting_path!(rank_matching, invgraph, var, - eq -> eq in eqs_set, eqcolor) + Base.Fix2(in, eqs_set), eqcolor) pathfound || continue push!(dummy_derivatives, var) rank += 1 @@ -239,5 +236,34 @@ function dummy_derivative_graph!(structure::SystemStructure, var_eq_matching, ja end end - dummy_derivatives + dummy_derivatives_set = BitSet(dummy_derivatives) + # We can eliminate variables that are not a selected state (differential + # variables). Selected states are differentiated variables that are not + # dummy derivatives. + can_eliminate = let var_to_diff = var_to_diff, dummy_derivatives_set = dummy_derivatives_set + + v -> begin + dv = var_to_diff[v] + dv === nothing || dv in dummy_derivatives_set + end + end + + # We don't want tearing to give us `y_t ~ D(y)`, so we skip equations with + # actually differentiated variables. + isdiffed = let diff_to_var = diff_to_var, dummy_derivatives_set = dummy_derivatives_set + v -> diff_to_var[v] !== nothing && !(v in dummy_derivatives_set) + end + should_consider = let graph = graph, isdiffed = isdiffed + eq -> !any(isdiffed, 𝑠neighbors(graph, eq)) + end + + var_eq_matching = tear_graph_modia(structure, Union{Unassigned, SelectedState}; + varfilter = can_eliminate, + eqfilter = should_consider) + for v in eachindex(var_eq_matching) + can_eliminate(v) && continue + var_eq_matching[v] = SelectedState() + end + + return var_eq_matching end diff --git a/src/structural_transformation/symbolics_tearing.jl b/src/structural_transformation/symbolics_tearing.jl index 9ffcfb343f..ea9dcfe7c3 100644 --- a/src/structural_transformation/symbolics_tearing.jl +++ b/src/structural_transformation/symbolics_tearing.jl @@ -48,7 +48,7 @@ function eq_derivative!(ts::TearingState{ODESystem}, ieq::Int) # Analyze the new equation and update the graph/solvable_graph # First, copy the previous incidence and add the derivative terms. # That's a superset of all possible occurrences. find_solvables! will - # remove those that doen't actually occur. + # remove those that doesn't actually occur. eq_diff = length(equations(ts)) for var in 𝑠neighbors(s.graph, ieq) add_edge!(s.graph, eq_diff, var) @@ -115,28 +115,153 @@ function solve_equation(eq, var, simplify) var ~ rhs end -function tearing_reassemble(state::TearingState, var_eq_matching; simplify = false) +# From the index of `D(x)` find the equation `D(x) ~ x_t` and the variable +# `x_t`. +function has_order_lowering_eq_var(eqs, fullvars, graph, var_to_diff, dx_idx)::Union{Nothing, NTuple{2, Int}} + diff_to_var = invview(var_to_diff) + diff_to_var[dx_idx] === nothing && return nothing + + dx = fullvars[dx_idx] + for eq in 𝑑neighbors(graph, dx_idx) + vs = 𝑠neighbors(graph, eq) + length(vs) == 2 || continue + maybe_x_t_idx = vs[1] == dx_idx ? vs[2] : vs[1] + # TODO: should we follow the differentiation chain? I.e. recurse until + # all reachable variables are explored or `diff_to_var[maybe_x_t_idx] === nothing` + diff_to_var[maybe_x_t_idx] === nothing || continue + maybe_x_t = fullvars[maybe_x_t_idx] + difference = (eqs[eq].lhs - eqs[eq].rhs) - (dx - maybe_x_t) + # if `eq` is in the form of `D(x) ~ x_t` + if ModelingToolkit._iszero(difference) + # TODO: reduce systems with multiple order lowering `eq` and `var` + # as well. + return eq, maybe_x_t_idx + end + end + return nothing +end + +function var2var_t_map(state::TearingState) + fullvars = state.fullvars + @unpack var_to_diff, graph = state.structure + eqs = equations(state) + @info "" eqs + var2var_t = Vector{Union{Nothing, NTuple{2, Int}}}(undef, ndsts(graph)) + for v in 1:ndsts(graph) + var2var_t[v] = has_order_lowering_eq_var(eqs, fullvars, graph, var_to_diff, v) + end + var2var_t +end + +function substitute_vars!(graph::BipartiteGraph, subs, cache=Int[], callback! = nothing; exclude = ()) + for su in subs + su === nothing && continue + v, v′ = su + eqs = 𝑑neighbors(graph, v) + # Note that the iterator is not robust under deletion and + # insertion. Hence, we have a copy here. + resize!(cache, length(eqs)) + for eq in copyto!(cache, eqs) + eq in exclude && continue + rem_edge!(graph, eq, v) + add_edge!(graph, eq, v′) + callback! !== nothing && callback!(eq, su) + end + end + graph +end + +function tearing_reassemble(state::TearingState, var_eq_matching, var2var_t = var2var_t_map(state); simplify = false) fullvars = state.fullvars @unpack solvable_graph, var_to_diff, eq_to_diff, graph = state.structure neweqs = collect(equations(state)) - - ### Replace derivatives of non-selected states by dumy derivatives - dummy_subs = Dict() - for var in 1:length(fullvars) - invview(var_to_diff)[var] === nothing && continue - if var_eq_matching[invview(var_to_diff)[var]] !== SelectedState() - fullvar = fullvars[var] - subst_fullvar = tearing_sub(fullvar, dummy_subs, simplify) - dummy_subs[fullvar] = fullvars[var] = diff2term(unwrap(subst_fullvar)) - var_to_diff[invview(var_to_diff)[var]] = nothing + # substitution utilities + idx_buffer = Int[] + sub_callback! = let eqs = neweqs, fullvars = fullvars + (ieq, s) -> begin + neweq = substitute(eqs[ieq], fullvars[s[1]] => fullvars[s[2]]) + @info "substitute" eqs[ieq] neweq + eqs[ieq] = neweq end end - if !isempty(dummy_subs) - neweqs = map(neweqs) do eq - 0 ~ tearing_sub(eq.rhs - eq.lhs, dummy_subs, simplify) + + # Terminology and Definition: + # + # A general DAE is in the form of `F(u'(t), u(t), p, t) == 0`. We can + # characterize variables in `u(t)` into two classes: differential variables + # (denoted `v(t)`) and algebraic variables (denoted `z(t)`). Differential + # variables are marked as `SelectedState` and they are differentiated in the + # DAE system, i.e. `v'(t)` are all the variables in `u'(t)` that actually + # appear in the system. Algebraic variables are variables that are not + # differential variables. + # + # Dummy derivatives may determine that some differential variables are + # algebraic variables in disguise. The derivative of such variables are + # called dummy derivatives. + + # Step 1: + # Replace derivatives of non-selected states by dummy derivatives + + null_eq = 0 ~ 0 + @info "Before" neweqs + @info "" fullvars + removed_eqs = Int[] + removed_vars = Int[] + diff_to_var = invview(var_to_diff) + var2idx = Dict(reverse(en) for en in enumerate(fullvars)) + for var in 1:length(fullvars) + dv = var_to_diff[var] + dv === nothing && continue + if var_eq_matching[var] !== SelectedState() + @warn "processing" fullvars[dv] + dd = fullvars[dv] + # convert `D(x)` to `x_t` (don't rely on the specific spelling of + # the name) + eq_var_t = var2var_t[dv] + idx = findfirst(x->x !== nothing && x[2] == var, var2var_t) + has_dummy_var = idx !== nothing && var_to_diff[var2var_t[idx][2]] !== nothing && var_eq_matching[idx] !== SelectedState() + if eq_var_t !== nothing # if we already have `v_t` + eq_idx, v_t = eq_var_t + push!(removed_eqs, eq_idx) + push!(removed_vars, dv) + substitute_vars!(graph, ((dv => v_t),), idx_buffer, sub_callback!; exclude = eq_idx) + substitute_vars!(solvable_graph, ((dv => v_t),), idx_buffer; exclude = eq_idx) + for g in (graph, solvable_graph) + vs = 𝑠neighbors(g, eq_idx) + resize!(idx_buffer, length(vs)) + for v in copyto!(idx_buffer, vs) + rem_edge!(g, eq_idx, v) + end + end + neweqs[eq_idx] = null_eq # TODO: we don't have to do this + #elseif has_dummy_var + # #eq_idx, v_t = eq_var_t + # #push!(removed_eqs, eq_idx) + # push!(removed_vars, dv) + # substitute_vars!(graph, ((dv => idx),), idx_buffer, sub_callback!) + else + # TODO: figure this out structurally + v_t = diff2term(unwrap(dd)) + v_t_idx = get(var2idx, v_t, nothing) + if v_t_idx isa Int + substitute_vars!(graph, ((dv => v_t_idx),), idx_buffer, sub_callback!) + else + for eq in 𝑑neighbors(graph, dv) + neweqs[eq] = substitute(neweqs[eq], fullvars[dv] => v_t) + end + fullvars[dv] = v_t + end + end + # update the structural information + diff_to_var[dv] = nothing end end + @info "" fullvars fullvars[removed_vars] + + # `SelectedState` information is no longer needed past here. State selection + # is done. All non-differentiated variables are algebraic variables, and all + # variables that appear differentiated are differential variables. ### extract partition information is_solvable(eq, iv) = isa(eq, Int) && BipartiteEdge(eq, iv) in solvable_graph @@ -145,11 +270,198 @@ function tearing_reassemble(state::TearingState, var_eq_matching; simplify = fal solved_variables = Int[] # if var is like D(x) - function isdiffvar(var) - invview(var_to_diff)[var] !== nothing && - var_eq_matching[invview(var_to_diff)[var]] === SelectedState() + isdiffvar = let diff_to_var = diff_to_var + var -> diff_to_var[var] !== nothing + end + + # There are three cases where we want to generate new variables to convert + # the system into first order (semi-implicit) ODEs. + # + # 1. To first order: + # Whenever higher order differentiated variable like `D(D(D(x)))` appears, + # we introduce new variables `x_t`, `x_tt`, and `x_ttt` and new equations + # ``` + # D(x_tt) = x_ttt + # D(x_t) = x_tt + # D(x) = x_t + # ``` + # and replace `D(x)` to `x_t`, `D(D(x))` to `x_tt`, and `D(D(D(x)))` to + # `x_ttt`. + # + # 2. To implicit to semi-implicit ODEs: + # 2.1: Unsolvable derivative: + # If one derivative variable `D(x)` is unsolvable in all the equations it + # appears in, then we introduce a new variable `x_t`, a new equation + # ``` + # D(x) ~ x_t + # ``` + # and replace all other `D(x)` to `x_t`. + # + # 2.2: Solvable derivative: + # If one derivative variable `D(x)` is solvable in at least one of the + # equations it appears in, then we introduce a new variable `x_t`. One of + # the solvable equations must be in the form of `0 ~ L(D(x), u...)` and + # there exists a function `l` such that `D(x) ~ l(u...)`. We should replace + # it to + # ``` + # 0 ~ x_t - l(u...) + # D(x) ~ x_t + # ``` + # and replace all other `D(x)` to `x_t`. + # + # Observe that we don't need to actually introduce a new variable `x_t`, as + # the above equations can be lowered to + # ``` + # x_t := l(u...) + # D(x) ~ x_t + # ``` + # where `:=` denotes assignment. + # + # As a final note, in all the above cases where we need to introduce new + # variables and equations, don't add them when they already exist. + + @info "After dummy der" neweqs + var_to_idx = Dict{Any, Int}(reverse(en) for en in enumerate(fullvars)) + if ModelingToolkit.has_iv(state.sys) + iv = get_iv(state.sys) + D = Differential(iv) + else + iv = D = nothing end + nvars = ndsts(graph) + processed = falses(nvars) + subinfo = NTuple{3, Int}[] + for i in 1:nvars + processed[i] && continue + + v = i + # descend to the bottom of differentiation chain + while diff_to_var[v] !== nothing + v = diff_to_var[v] + end + # `v` is now not differentiated at level 0. + diffvar = v + processed[v] = true + level = 0 + order = 0 + isimplicit = false + # ascend to the top of differentiation chain + while true + eqs_with_v = 𝑑neighbors(graph, v) + if !isempty(eqs_with_v) + order = level + isimplicit = length(eqs_with_v) > 1 || !is_solvable(only(eqs_with_v), v) + end + if v <= length(processed) + processed[v] = true + end + var_to_diff[v] === nothing && break + v = var_to_diff[v] + level += 1 + end + + # `diffvar` is a order `order` variable + (isimplicit || order > 1) || continue + + # add `D(t) ~ x_t` etc + subs = Dict() + ogx = x = fullvars[diffvar] # x + ogidx = xidx = diffvar + # We shouldn't apply substitution to `order_lowering_eqs` + order_lowering_eqs = BitSet() + for o in 1:order + # D(x) ~ x_t + ogidx = var_to_diff[ogidx] + + has_x_t = false + x_t_idx::Union{Nothing, Int} = nothing + dx_idx = var_to_diff[xidx] + if dx_idx === nothing + dx = D(x) + push!(fullvars, dx) + dx_idx = add_vertex!(var_to_diff) + add_vertex!(graph, DST) + add_vertex!(solvable_graph, DST) + @assert dx_idx == ndsts(graph) == length(fullvars) + push!(var_eq_matching, unassigned) + + var_to_diff[xidx] = dx_idx + else + dx = fullvars[dx_idx] + var_eq_matching[dx_idx] = unassigned + + for eq in 𝑑neighbors(graph, dx_idx) + vs = 𝑠neighbors(graph, eq) + length(vs) == 2 || continue + maybe_x_t_idx = vs[1] == dx_idx ? vs[2] : vs[1] + maybe_x_t = fullvars[maybe_x_t_idx] + difference = (neweqs[eq].lhs - neweqs[eq].rhs) - (dx - maybe_x_t) + # if `eq` is in the form of `D(x) ~ x_t` + if ModelingToolkit._iszero(difference) + x_t_idx = maybe_x_t_idx + x_t = maybe_x_t + eq_idx = eq + push!(order_lowering_eqs, eq_idx) + has_x_t = true + break + end + end + end + + if x_t_idx === nothing + x_t = ModelingToolkit.lower_varname(ogx, iv, o) + push!(fullvars, x_t) + x_t_idx = add_vertex!(var_to_diff) + add_vertex!(graph, DST) + add_vertex!(solvable_graph, DST) + @assert x_t_idx == ndsts(graph) == length(fullvars) + push!(var_eq_matching, unassigned) + end + x_t_idx::Int + + if !has_x_t + push!(neweqs, dx ~ x_t) + eq_idx = add_vertex!(eq_to_diff) + push!(order_lowering_eqs, eq_idx) + add_vertex!(graph, SRC) + add_vertex!(solvable_graph, SRC) + @assert eq_idx == nsrcs(graph) == length(neweqs) + + add_edge!(solvable_graph, eq_idx, x_t_idx) + add_edge!(solvable_graph, eq_idx, dx_idx) + add_edge!(graph, eq_idx, x_t_idx) + add_edge!(graph, eq_idx, dx_idx) + + end + # We use this info to substitute all `D(D(x))` or `D(x_t)` except + # the `D(D(x)) ~ x_tt` equation to `x_tt`. + # D(D(x)) D(x_t) x_tt + push!(subinfo, (ogidx, dx_idx, x_t_idx)) + + # D(x_t) ~ x_tt + x = x_t + xidx = x_t_idx + end + + # Go backward from high order to lower order so that we substitute + # something like `D(D(x)) -> x_tt` first, otherwise we get `D(x_t)` + # which would be hard to fix up before we finish lower the order of + # variable `x`. + for (ogidx, dx_idx, x_t_idx) in Iterators.reverse(subinfo) + # We need a loop here because both `D(D(x))` and `D(x_t)` need to be + # substituted to `x_tt`. + for idx in (ogidx, dx_idx) + subidx = ((idx => x_t_idx),) + substitute_vars!(graph, subidx, idx_buffer, sub_callback!; exclude = order_lowering_eqs) + substitute_vars!(solvable_graph, subidx, idx_buffer; exclude = order_lowering_eqs) + end + end + empty!(subinfo) + empty!(subs) + end + + @info "After implicit to semi-implicit" neweqs # Rewrite remaining equations in terms of solved variables function to_mass_matrix_form(ieq) eq = neweqs[ieq] @@ -158,7 +470,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching; simplify = fal end rhs = eq.rhs if rhs isa Symbolic - # Check if the rhs is solvable in all state derivatives and if those + # Check if the RHS is solvable in all state derivatives and if those # the linear terms for them are all zero. If so, move them to the # LHS. dterms = [var for var in 𝑠neighbors(graph, ieq) if isdiffvar(var)] @@ -187,33 +499,53 @@ function tearing_reassemble(state::TearingState, var_eq_matching; simplify = fal end diffeq_idxs = BitSet() - diffeqs = Equation[] + final_eqs = Equation[] + var_rename = zeros(Int, length(var_eq_matching)) + subeqs = Equation[] + idx = 0 # Solve solvable equations for (iv, ieq) in enumerate(var_eq_matching) - is_solvable(ieq, iv) || continue - # We don't solve differential equations, but we will need to try to - # convert it into the mass matrix form. - # We cannot solve the differential variable like D(x) - if isdiffvar(iv) - push!(diffeqs, to_mass_matrix_form(ieq)) - push!(diffeq_idxs, ieq) - continue + if is_solvable(ieq, iv) + # We don't solve differential equations, but we will need to try to + # convert it into the mass matrix form. + # We cannot solve the differential variable like D(x) + if isdiffvar(iv) + # TODO: what if `to_mass_matrix_form(ieq)` returns `nothing`? + push!(final_eqs, to_mass_matrix_form(ieq)) + push!(diffeq_idxs, ieq) + var_rename[iv] = (idx += 1) + continue + end + eq = neweqs[ieq] + var = fullvars[iv] + residual = eq.lhs - eq.rhs + a, b, islinear = linear_expansion(residual, var) + # 0 ~ a * var + b + # var ~ -b/a + if ModelingToolkit._iszero(a) + push!(removed_eqs, ieq) + push!(removed_vars, iv) + else + rhs = -b/a + neweq = var ~ simplify ? Symbolics.simplify(rhs) : rhs + push!(subeqs, neweq) + push!(solved_equations, ieq) + push!(solved_variables, iv) + end + var_rename[iv] = -1 + else + var_rename[iv] = (idx += 1) end - push!(solved_equations, ieq) - push!(solved_variables, iv) end if isempty(solved_equations) - subeqs = Equation[] deps = Vector{Int}[] else subgraph = substitution_graph(graph, solved_equations, solved_variables, var_eq_matching) toporder = topological_sort_by_dfs(subgraph) - subeqs = Equation[solve_equation(neweqs[solved_equations[i]], - fullvars[solved_variables[i]], - simplify) for i in toporder] - # find the dependency of solved variables. we will need this for ODAEProblem + subeqs = subeqs[toporder] + # Find the dependency of solved variables. We will need this for ODAEProblem invtoporder = invperm(toporder) deps = [Int[invtoporder[n] for n in neighborhood(subgraph, j, Inf, dir = :in) if n != j] @@ -223,29 +555,39 @@ function tearing_reassemble(state::TearingState, var_eq_matching; simplify = fal # TODO: BLT sorting # Rewrite remaining equations in terms of solved variables solved_eq_set = BitSet(solved_equations) - neweqs = Equation[to_mass_matrix_form(ieq) - for ieq in 1:length(neweqs) - if !(ieq in diffeq_idxs || ieq in solved_eq_set)] - filter!(!isnothing, neweqs) - prepend!(neweqs, diffeqs) + for ieq in 1:length(neweqs) + (ieq in diffeq_idxs || ieq in solved_eq_set) && continue + maybe_eq = to_mass_matrix_form(ieq) + maybe_eq === nothing || push!(final_eqs, maybe_eq) + end + neweqs = final_eqs # Contract the vertices in the structure graph to make the structure match # the new reality of the system we've just created. + # + # TODO: fix ordering and remove equations graph = contract_variables(graph, var_eq_matching, solved_variables) # Update system - active_vars = setdiff(BitSet(1:length(fullvars)), solved_variables) + solved_variables_set = BitSet(solved_variables) + active_vars = setdiff!(setdiff(BitSet(1:length(fullvars)), solved_variables_set), removed_vars) + new_var_to_diff = complete(DiffGraph(length(active_vars))) + idx = 0 + for (v, d) in enumerate(var_to_diff) + v′ = var_rename[v] + (v′ > 0 && d !== nothing) || continue + d′ = var_rename[d] + new_var_to_diff[v′] = d′ > 0 ? d′ : nothing + end @set! state.structure.graph = graph + # Note that `eq_to_diff` is not updated + @set! state.structure.var_to_diff = new_var_to_diff @set! state.fullvars = [v for (i, v) in enumerate(fullvars) if i in active_vars] sys = state.sys @set! sys.eqs = neweqs - function isstatediff(i) - var_eq_matching[i] !== SelectedState() && invview(var_to_diff)[i] !== nothing && - var_eq_matching[invview(var_to_diff)[i]] === SelectedState() - end - @set! sys.states = [fullvars[i] for i in active_vars if !isstatediff(i)] + @set! sys.states = [fullvars[i] for i in active_vars if diff_to_var[i] === nothing] @set! sys.observed = [observed(sys); subeqs] @set! sys.substitutions = Substitutions(subeqs, deps) @set! state.sys = sys @@ -300,17 +642,14 @@ end """ dummy_derivative(sys) -Perform index reduction and use the dummy derivative techinque to ensure that +Perform index reduction and use the dummy derivative technique to ensure that the system is balanced. """ -function dummy_derivative(sys, state = TearingState(sys)) +function dummy_derivative(sys, state = TearingState(sys); simplify = false, kwargs...) function jac(eqs, vars) symeqs = EquationsView(state)[eqs] Symbolics.jacobian((x -> x.rhs).(symeqs), state.fullvars[vars]) end - dds = dummy_derivative_graph!(state, jac) - symdds = Symbolics.diff2term.(state.fullvars[dds]) - subs = Dict(state.fullvars[dd] => symdds[i] for (i, dd) in enumerate(dds)) - @set! sys.eqs = substitute.(EquationsView(state), (subs,)) - @set! sys.states = [states(sys); symdds] + var_eq_matching = dummy_derivative_graph!(state, jac; kwargs...) + tearing_reassemble(state, var_eq_matching; simplify = simplify) end diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index e2cba876ff..6c93265c20 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -1008,12 +1008,7 @@ function structural_simplify(sys::AbstractSystem; simplify = false, kwargs...) state = inputs_to_parameters!(state) sys = state.sys check_consistency(state) - if sys isa ODESystem - sys = dae_order_lowering(dummy_derivative(sys, state)) - end - state = TearingState(sys) - find_solvables!(state; kwargs...) - sys = tearing_reassemble(state, tearing(state), simplify = simplify) + sys = dummy_derivative(sys, state) fullstates = [map(eq -> eq.lhs, observed(sys)); states(sys)] @set! sys.observed = topsort_equations(observed(sys), fullstates) invalidate_cache!(sys) diff --git a/src/systems/alias_elimination.jl b/src/systems/alias_elimination.jl index 53ca5a610f..a9d7eee845 100644 --- a/src/systems/alias_elimination.jl +++ b/src/systems/alias_elimination.jl @@ -144,14 +144,16 @@ function Base.getindex(ag::AliasGraph, i::Integer) r = ag.aliasto[i] r === nothing && throw(KeyError(i)) coeff, var = (sign(r), abs(r)) + nc = coeff + av = var if var in keys(ag) # Amortized lookup. Check if since we last looked this up, our alias was # itself aliased. If so, just adjust the alias table. ac, av = ag[var] nc = ac * coeff - ag.aliasto[var] = nc > 0 ? av : -av + ag.aliasto[i] = nc > 0 ? av : -av end - return (coeff, var) + return (nc, av) end function Base.iterate(ag::AliasGraph, state...) diff --git a/test/alias.jl b/test/alias.jl new file mode 100644 index 0000000000..6cd192543e --- /dev/null +++ b/test/alias.jl @@ -0,0 +1,20 @@ +using Test +using ModelingToolkit: AliasGraph + +ag = AliasGraph(10) +ag[1] = 1 => 2 +ag[2] = -1 => 3 +ag[4] = -1 => 1 +ag[5] = -1 => 4 +for _ in 1:5 # check ag is robust + @test ag[1] == (-1, 3) + @test ag[2] == (-1, 3) + @test ag[4] == (1, 3) + @test ag[5] == (-1, 3) +end + +@test 1 in keys(ag) +@test 2 in keys(ag) +@test !(3 in keys(ag)) +@test 4 in keys(ag) +@test 5 in keys(ag) diff --git a/test/input_output_handling.jl b/test/input_output_handling.jl index 5a31a7bb68..fcddfe8d53 100644 --- a/test/input_output_handling.jl +++ b/test/input_output_handling.jl @@ -118,12 +118,12 @@ u = [rand()] @variables u(t) [input = true] function Mass(; name, m = 1.0, p = 0, v = 0) - @variables y(t) [output = true] + @variables y(t)=0 [output = true] ps = @parameters m = m sts = @variables pos(t)=p vel(t)=v eqs = [D(pos) ~ vel y ~ pos] - ODESystem(eqs, t, [pos, vel], ps; name) + ODESystem(eqs, t, [pos, vel, y], ps; name) end function Spring(; name, k = 1e4) @@ -166,15 +166,15 @@ eqs = [connect_sd(sd, mass1, mass2) f, dvs, ps = ModelingToolkit.generate_control_function(model, expression = Val{false}, simplify = true) -@test length(dvs) == 4 @test length(ps) == length(parameters(model)) p = ModelingToolkit.varmap_to_vars(ModelingToolkit.defaults(model), ps) x = ModelingToolkit.varmap_to_vars(ModelingToolkit.defaults(model), dvs) u = [rand()] -@test f[1](x, u, p, 1) == [u; 0; 0; 0] +out = f[1](x, u, p, 1) +@test out[1] == u[1] && iszero(out[2:end]) @parameters t @variables x(t) u(t) [input = true] eqs = [Differential(t)(x) ~ u] @named sys = ODESystem(eqs, t) -structural_simplify(sys) +@test_nowarn structural_simplify(sys) diff --git a/test/nonlinearsystem.jl b/test/nonlinearsystem.jl index 141637a8d1..179f7f64fd 100644 --- a/test/nonlinearsystem.jl +++ b/test/nonlinearsystem.jl @@ -195,7 +195,7 @@ let u[4] ~ 1] sys = NonlinearSystem(eqs, collect(u[1:4]), Num[], defaults = Dict([]), name = :test) - prob = NonlinearProblem(sys, ones(length(sys.states))) + prob = NonlinearProblem(sys, ones(length(states(sys)))) sol = NonlinearSolve.solve(prob, NewtonRaphson()) diff --git a/test/runtests.jl b/test/runtests.jl index 6f33dcf01d..820ac77a6c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,6 @@ using SafeTestsets, Test +@safetestset "AliasGraph Test" begin include("alias.jl") end @safetestset "Linear Algebra Test" begin include("linalg.jl") end @safetestset "AbstractSystem Test" begin include("abstractsystem.jl") end @safetestset "Variable scope tests" begin include("variable_scope.jl") end diff --git a/test/state_selection.jl b/test/state_selection.jl index 66eaee6b36..a0e3bf68d5 100644 --- a/test/state_selection.jl +++ b/test/state_selection.jl @@ -1,4 +1,4 @@ -using ModelingToolkit, OrdinaryDiffEq, Test +using ModelingToolkit, OrdinaryDiffEq, IfElse, Test @variables t sts = @variables x1(t) x2(t) x3(t) x4(t) @@ -18,11 +18,10 @@ let dd = dummy_derivative(sys) has_dx2 |= D(x2) in vars || D(D(x2)) in vars end @test has_dx1 ⊻ has_dx2 # only one of x1 and x2 can be a dummy derivative - @test length(states(dd)) == length(equations(dd)) == 9 - @test length(states(structural_simplify(dd))) < 9 + @test length(states(dd)) == length(equations(dd)) < 9 end -let pss = partial_state_selection(sys) +@test_skip let pss = partial_state_selection(sys) @test length(equations(pss)) == 1 @test length(states(pss)) == 2 @test length(equations(ode_order_lowering(pss))) == 2 @@ -122,14 +121,16 @@ let end @named system = System(L = 10) - @unpack supply_pipe = system + @unpack supply_pipe, return_pipe = system sys = structural_simplify(system) - u0 = [system.supply_pipe.v => 0.1, system.return_pipe.v => 0.1, D(supply_pipe.v) => 0.0] - # This is actually an implicit DAE system - @test_throws Any ODEProblem(sys, u0, (0.0, 10.0), []) - @test_throws Any ODAEProblem(sys, u0, (0.0, 10.0), []) - prob = DAEProblem(sys, D.(states(sys)) .=> 0.0, u0, (0.0, 10.0), []) - @test solve(prob, DFBDF()).retcode == :Success + u0 = [system.supply_pipe.v => 0.1, system.return_pipe.v => 0.1, D(supply_pipe.v) => 0.0, + D(return_pipe.fluid_port_a.m) => 0.0] + prob1 = ODEProblem(sys, u0, (0.0, 10.0), []) + prob2 = ODAEProblem(sys, u0, (0.0, 10.0), []) + prob3 = DAEProblem(sys, D.(states(sys)) .=> 0.0, u0, (0.0, 10.0), []) + @test solve(prob1, FBDF()).retcode == :Success + @test solve(prob2, FBDF()).retcode == :Success + @test solve(prob3, DFBDF()).retcode == :Success end # 1537 @@ -189,7 +190,96 @@ let rho_3 => 1.3 mo_1 => 0 mo_2 => 1 - mo_3 => 2] - prob = ODAEProblem(sys, u0, (0.0, 0.1)) - @test solve(prob, FBDF()).retcode == :Success + mo_3 => 2 + Ek_3 => 3] + prob1 = ODEProblem(sys, u0, (0.0, 0.1)) + prob2 = ODAEProblem(sys, u0, (0.0, 0.1)) + @test solve(prob1, FBDF()).retcode == :Success + @test_broken solve(prob2, FBDF()).retcode == :Success +end + +let + # constant parameters ---------------------------------------------------- + A_1f = 0.0908 + A_2f = 0.036 + p_1f_0 = 1.8e6 + p_2f_0 = p_1f_0 * A_1f / A_2f + m_total = 3245 + K1 = 4.60425e-5 + K2 = 0.346725 + K3 = 0 + density = 876 + bulk = 1.2e9 + l_1f = 0.7 + x_f_fullscale = 0.025 + p_s = 200e5 + # -------------------------------------------------------------------------- + + # modelingtoolkit setup ---------------------------------------------------- + @parameters t + params = @parameters l_2f=0.7 damp=1e3 + vars = @variables begin + p1(t) + p2(t) + dp1(t) = 0 + dp2(t) = 0 + xf(t) = 0 + rho1(t) + rho2(t) + drho1(t) = 0 + drho2(t) = 0 + V1(t) + V2(t) + dV1(t) = 0 + dV2(t) = 0 + w(t) = 0 + dw(t) = 0 + ddw(t) = 0 + end + D = Differential(t) + + defs = [p1 => p_1f_0 + p2 => p_2f_0 + rho1 => density * (1 + p_1f_0 / bulk) + rho2 => density * (1 + p_2f_0 / bulk) + V1 => l_1f * A_1f + V2 => l_2f * A_2f + D(p1) => dp1 + D(p2) => dp2 + D(w) => dw + D(dw) => ddw] + + # equations ------------------------------------------------------------------ + flow(x, dp) = K1 * abs(dp) * abs(x) + K2 * sqrt(abs(dp)) * abs(x) + K3 * abs(dp) * x^2 + xm = xf / x_f_fullscale + Δp1 = p_s - p1 + Δp2 = p2 + + eqs = [+flow(xm, Δp1) ~ rho1 * dV1 + drho1 * V1 + 0 ~ IfElse.ifelse(w > 0.5, + (0) - (rho2 * dV2 + drho2 * V2), + (-flow(xm, Δp2)) - (rho2 * dV2 + drho2 * V2)) + V1 ~ (l_1f + w) * A_1f + V2 ~ (l_2f - w) * A_2f + dV1 ~ +dw * A_1f + dV2 ~ -dw * A_2f + rho1 ~ density * (1.0 + p1 / bulk) + rho2 ~ density * (1.0 + p2 / bulk) + drho1 ~ density * (dp1 / bulk) + drho2 ~ density * (dp2 / bulk) + D(p1) ~ dp1 + D(p2) ~ dp2 + D(w) ~ dw + D(dw) ~ ddw + xf ~ 20e-3 * (1 - cos(2 * π * 5 * t)) + 0 ~ IfElse.ifelse(w > 0.5, + (m_total * ddw) - (p1 * A_1f - p2 * A_2f - damp * dw), + (m_total * ddw) - (p1 * A_1f - p2 * A_2f))] + # ---------------------------------------------------------------------------- + + # solution ------------------------------------------------------------------- + @named catapult = ODESystem(eqs, t, vars, params, defaults = defs) + sys = structural_simplify(catapult) + prob = ODEProblem(sys, [], (0.0, 0.1), [l_2f => 0.55, damp => 1e7]; jac = true) + @test solve(prob, Rodas4()).retcode == :Success end diff --git a/test/structural_transformation/index_reduction.jl b/test/structural_transformation/index_reduction.jl index 5ad1ef3c97..e9011cc9a7 100644 --- a/test/structural_transformation/index_reduction.jl +++ b/test/structural_transformation/index_reduction.jl @@ -136,22 +136,28 @@ let pss_pendulum = partial_state_selection(pendulum) @test_broken length(equations(pss_pendulum)) == 3 end -sys = structural_simplify(pendulum2) -@test length(equations(sys)) == 5 -@test length(states(sys)) == 5 - -u0 = [ - D(x) => 0.0, - D(y) => 0.0, - x => sqrt(2) / 2, - y => sqrt(2) / 2, - T => 0.0, +for sys in [ + structural_simplify(pendulum2), + structural_simplify(ode_order_lowering(pendulum2)), ] -p = [ - L => 1.0, - g => 9.8, -] - -prob_auto = DAEProblem(sys, zeros(length(u0)), u0, (0.0, 0.2), p) -sol = solve(prob_auto, DFBDF()) -@test norm(sol[x] .^ 2 + sol[y] .^ 2 .- 1) < 1e-2 + @test length(equations(sys)) == 5 + @test length(states(sys)) == 5 + + u0 = [ + D(x) => 0.0, + D(D(x)) => 0.0, + D(y) => 0.0, + D(D(y)) => 0.0, + x => sqrt(2) / 2, + y => sqrt(2) / 2, + T => 0.0, + ] + p = [ + L => 1.0, + g => 9.8, + ] + + prob_auto = ODEProblem(sys, u0, (0.0, 1.0), p) + sol = solve(prob_auto, FBDF()) + @test norm(sol[x] .^ 2 + sol[y] .^ 2 .- 1) < 1e-2 +end