Skip to content

Commit

Permalink
Merge pull request #2805 from SciML/myb/conservative
Browse files Browse the repository at this point in the history
Add `conservative` kwarg in `structural_transformation`
  • Loading branch information
ChrisRackauckas committed Jun 15, 2024
2 parents 0cc954d + 898e592 commit 22fb377
Show file tree
Hide file tree
Showing 7 changed files with 30 additions and 11 deletions.
5 changes: 3 additions & 2 deletions src/structural_transformation/pantelides.jl
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,8 @@ end
Perform Pantelides algorithm.
"""
function pantelides!(state::TransformationState; finalize = true, maxiters = 8000)
function pantelides!(
state::TransformationState; finalize = true, maxiters = 8000, kwargs...)
@unpack graph, solvable_graph, var_to_diff, eq_to_diff = state.structure
neqs = nsrcs(graph)
nvars = nv(var_to_diff)
Expand Down Expand Up @@ -181,7 +182,7 @@ function pantelides!(state::TransformationState; finalize = true, maxiters = 800
ecolor[eq] || continue
# introduce a new equation
neqs += 1
eq_derivative!(state, eq)
eq_derivative!(state, eq; kwargs...)
end

for var in eachindex(vcolor)
Expand Down
2 changes: 1 addition & 1 deletion src/structural_transformation/partial_state_selection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ function dummy_derivative_graph!(state::TransformationState, jac = nothing;
state_priority = nothing, log = Val(false), kwargs...)
state.structure.solvable_graph === nothing && find_solvables!(state; kwargs...)
complete!(state.structure)
var_eq_matching = complete(pantelides!(state))
var_eq_matching = complete(pantelides!(state; kwargs...))
dummy_derivative_graph!(state.structure, var_eq_matching, jac, state_priority, log)
end

Expand Down
5 changes: 3 additions & 2 deletions src/structural_transformation/symbolics_tearing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ function eq_derivative_graph!(s::SystemStructure, eq::Int)
return eq_diff
end

function eq_derivative!(ts::TearingState{ODESystem}, ieq::Int)
function eq_derivative!(ts::TearingState{ODESystem}, ieq::Int; kwargs...)
s = ts.structure

eq_diff = eq_derivative_graph!(s, ieq)
Expand All @@ -75,7 +75,8 @@ function eq_derivative!(ts::TearingState{ODESystem}, ieq::Int)
add_edge!(s.graph, eq_diff, s.var_to_diff[var])
end
s.solvable_graph === nothing ||
find_eq_solvables!(ts, eq_diff; may_be_zero = true, allow_symbolic = false)
find_eq_solvables!(
ts, eq_diff; may_be_zero = true, allow_symbolic = false, kwargs...)

return eq_diff
end
Expand Down
5 changes: 4 additions & 1 deletion src/structural_transformation/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,9 @@ end

function find_eq_solvables!(state::TearingState, ieq, to_rm = Int[], coeffs = nothing;
may_be_zero = false,
allow_symbolic = false, allow_parameter = true, kwargs...)
allow_symbolic = false, allow_parameter = true,
conservative = false,
kwargs...)
fullvars = state.fullvars
@unpack graph, solvable_graph = state.structure
eq = equations(state)[ieq]
Expand Down Expand Up @@ -220,6 +222,7 @@ function find_eq_solvables!(state::TearingState, ieq, to_rm = Int[], coeffs = no
coeffs === nothing || push!(coeffs, convert(Int, a))
else
all_int_vars = false
conservative && continue
end
if a != 0
add_edge!(solvable_graph, ieq, j)
Expand Down
6 changes: 4 additions & 2 deletions src/systems/systems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@ $(SIGNATURES)
Structurally simplify algebraic equations in a system and compute the
topological sort of the observed equations. When `simplify=true`, the `simplify`
function will be applied during the tearing process. It also takes kwargs
`allow_symbolic=false` and `allow_parameter=true` which limits the coefficient
types during tearing.
`allow_symbolic=false`, `allow_parameter=true`, and `conservative=false` which
limits the coefficient types during tearing. In particular, `conservative=true`
limits tearing to only solve for trivial linear systems where the coefficient
has the absolute value of ``1``.
The optional argument `io` may take a tuple `(inputs, outputs)`.
This will convert all `inputs` to parameters and allow them to be unconnected, i.e.,
Expand Down
9 changes: 6 additions & 3 deletions src/systems/systemstructure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -691,15 +691,18 @@ function _structural_simplify!(state::TearingState, io; simplify = false,
ModelingToolkit.check_consistency(state, orig_inputs)
end
if fully_determined && dummy_derivative
sys = ModelingToolkit.dummy_derivative(sys, state; simplify, mm, check_consistency)
sys = ModelingToolkit.dummy_derivative(
sys, state; simplify, mm, check_consistency, kwargs...)
elseif fully_determined
var_eq_matching = pantelides!(state; finalize = false, kwargs...)
sys = pantelides_reassemble(state, var_eq_matching)
state = TearingState(sys)
sys, mm = ModelingToolkit.alias_elimination!(state; kwargs...)
sys = ModelingToolkit.dummy_derivative(sys, state; simplify, mm, check_consistency)
sys = ModelingToolkit.dummy_derivative(
sys, state; simplify, mm, check_consistency, kwargs...)
else
sys = ModelingToolkit.tearing(sys, state; simplify, mm, check_consistency)
sys = ModelingToolkit.tearing(
sys, state; simplify, mm, check_consistency, kwargs...)
end
fullunknowns = [map(eq -> eq.lhs, observed(sys)); unknowns(sys)]
@set! sys.observed = ModelingToolkit.topsort_equations(observed(sys), fullunknowns)
Expand Down
9 changes: 9 additions & 0 deletions test/nonlinearsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -274,3 +274,12 @@ eqs = [u3 ~ u1 + u2, u4 ~ 2 * (u1 + u2), u3 + u4 ~ 3 * (u1 + u2)]
@named ns = NonlinearSystem(eqs, [u1, u2], [u3, u4])
sys = structural_simplify(ns; fully_determined = false)
@test length(unknowns(sys)) == 1

# Conservative
@variables X(t)
alg_eqs = [1 ~ 2X]
@named ns = NonlinearSystem(alg_eqs)
sys = structural_simplify(ns)
@test length(equations(sys)) == 0
sys = structural_simplify(ns; conservative = true)
@test length(equations(sys)) == 1

0 comments on commit 22fb377

Please sign in to comment.