From 63eb6745dd34c2940a6b97b2fcd641681bc43d70 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Fri, 6 Jan 2023 07:32:43 -0500 Subject: [PATCH] Better extra variable reporting and reverse the arrow in `compute_diff_label` --- src/structural_transformation/utils.jl | 17 +++++++++++++---- src/systems/abstractsystem.jl | 7 +++++-- src/systems/systemstructure.jl | 15 +++++++++++---- test/input_output_handling.jl | 13 ++++++++++++- 4 files changed, 41 insertions(+), 11 deletions(-) diff --git a/src/structural_transformation/utils.jl b/src/structural_transformation/utils.jl index 5d49e85f61..5b2ae6c3d4 100644 --- a/src/structural_transformation/utils.jl +++ b/src/structural_transformation/utils.jl @@ -12,19 +12,28 @@ function BipartiteGraphs.maximal_matching(s::SystemStructure, eqfilter = eq -> t maximal_matching(s.graph, eqfilter, varfilter) end -function error_reporting(state, bad_idxs, n_highest_vars, iseqs) +function error_reporting(state, bad_idxs, n_highest_vars, iseqs, orig_inputs) io = IOBuffer() + neqs = length(equations(state)) if iseqs error_title = "More equations than variables, here are the potential extra equation(s):\n" out_arr = equations(state)[bad_idxs] else error_title = "More variables than equations, here are the potential extra variable(s):\n" out_arr = state.fullvars[bad_idxs] + unset_inputs = intersect(out_arr, orig_inputs) + n_missing_eqs = n_highest_vars - neqs + n_unset_inputs = length(unset_inputs) + if n_unset_inputs > 0 + println(io, "In particular, the unset input(s) are:") + Base.print_array(io, unset_inputs) + println(io) + println(io, "The rest of potentially unset variable(s) are:") + end end Base.print_array(io, out_arr) msg = String(take!(io)) - neqs = length(equations(state)) if iseqs throw(ExtraEquationsSystemException("The system is unbalanced. There are " * "$n_highest_vars highest order derivative variables " @@ -43,7 +52,7 @@ end ### ### Structural check ### -function check_consistency(state::TearingState, ag = nothing) +function check_consistency(state::TearingState, ag, orig_inputs) fullvars = state.fullvars @unpack graph, var_to_diff = state.structure n_highest_vars = count(v -> var_to_diff[v] === nothing && @@ -64,7 +73,7 @@ function check_consistency(state::TearingState, ag = nothing) else bad_idxs = findall(isequal(unassigned), var_eq_matching) end - error_reporting(state, bad_idxs, n_highest_vars, iseqs) + error_reporting(state, bad_idxs, n_highest_vars, iseqs, orig_inputs) end # This is defined to check if Pantelides algorithm terminates. For more diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 181484be92..5426a383c9 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -1193,7 +1193,7 @@ function linearization_function(sys::AbstractSystem, inputs, return lin_fun, sys end -function markio!(state, inputs, outputs; check = true) +function markio!(state, orig_inputs, inputs, outputs; check = true) fullvars = state.fullvars inputset = Dict{Any, Bool}(i => false for i in inputs) outputset = Dict{Any, Bool}(o => false for o in outputs) @@ -1207,6 +1207,9 @@ function markio!(state, inputs, outputs; check = true) outputset[v] = true fullvars[i] = v else + if isinput(v) + push!(orig_inputs, v) + end v = setio(v, false, false) fullvars[i] = v end @@ -1221,7 +1224,7 @@ function markio!(state, inputs, outputs; check = true) check && (all(values(outputset)) || error("Some specified outputs were not found in system. The following Dict indicates the found variables ", outputset)) - state + state, orig_inputs end """ diff --git a/src/systems/systemstructure.jl b/src/systems/systemstructure.jl index 262dc7f1ee..4296028f86 100644 --- a/src/systems/systemstructure.jl +++ b/src/systems/systemstructure.jl @@ -200,6 +200,10 @@ mutable struct TearingState{T <: AbstractSystem} <: AbstractTearingState{T} extra_eqs::Vector end +function Base.show(io::IO, state::TearingState) + print(io, "TearingState of ", typeof(state.sys)) +end + struct EquationsView{T} <: AbstractVector{Any} ts::TearingState{T} end @@ -386,9 +390,9 @@ Base.size(bgpm::SystemStructurePrintMatrix) = (max(nsrcs(bgpm.bpg), ndsts(bgpm.b function compute_diff_label(diff_graph, i) di = i - 1 <= length(diff_graph) ? diff_graph[i - 1] : nothing ii = i - 1 <= length(invview(diff_graph)) ? invview(diff_graph)[i - 1] : nothing - return Label(string(di === nothing ? "" : string(di, '↓'), + return Label(string(di === nothing ? "" : string(di, '↑'), di !== nothing && ii !== nothing ? " " : "", - ii === nothing ? "" : string(ii, '↑'))) + ii === nothing ? "" : string(ii, '↓'))) end function Base.getindex(bgpm::SystemStructurePrintMatrix, i::Integer, j::Integer) checkbounds(bgpm, i, j) @@ -519,11 +523,14 @@ end function _structural_simplify!(state::TearingState, io; simplify = false, check_consistency = true, kwargs...) has_io = io !== nothing - has_io && ModelingToolkit.markio!(state, io...) + orig_inputs = Set() + if has_io + ModelingToolkit.markio!(state, orig_inputs, io...) + end state, input_idxs = ModelingToolkit.inputs_to_parameters!(state, io) sys, ag = ModelingToolkit.alias_elimination!(state; kwargs...) if check_consistency - ModelingToolkit.check_consistency(state, ag) + ModelingToolkit.check_consistency(state, ag, orig_inputs) end sys = ModelingToolkit.dummy_derivative(sys, state, ag; simplify) fullstates = [map(eq -> eq.lhs, observed(sys)); states(sys)] diff --git a/test/input_output_handling.jl b/test/input_output_handling.jl index c146e524ab..7c034dd74d 100644 --- a/test/input_output_handling.jl +++ b/test/input_output_handling.jl @@ -1,6 +1,17 @@ using ModelingToolkit, Symbolics, Test using ModelingToolkit: get_namespace, has_var, inputs, outputs, is_bound, bound_inputs, - unbound_inputs, bound_outputs, unbound_outputs, isinput, isoutput + unbound_inputs, bound_outputs, unbound_outputs, isinput, isoutput, + ExtraVariablesSystemException + +@variables t xx(t) some_input(t) [input = true] +D = Differential(t) +eqs = [D(xx) ~ some_input] +@named model = ODESystem(eqs, t) +@test_throws ExtraVariablesSystemException structural_simplify(model, ((), ())) +if VERSION >= v"1.8" + err = "In particular, the unset input(s) are:\n some_input(t)" + @test_throws err structural_simplify(model, ((), ())) +end # Test input handling @parameters tv