Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions src/bipartite_graph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ end
function Matching(m::Int)
Matching{Unassigned}(Union{Int, Unassigned}[unassigned for _ in 1:m], nothing)
end
function Matching{U}(m::Int) where {U}
Matching{Union{Unassigned, U}}(Union{Int, Unassigned, U}[unassigned for _ in 1:m],
nothing)
end

Base.size(m::Matching) = Base.size(m.match)
Base.getindex(m::Matching, i::Integer) = m.match[i]
Expand All @@ -65,9 +69,9 @@ function Base.setindex!(m::Matching{U}, v::Union{Integer, U}, i::Integer) where
return m.match[i] = v
end

function Base.push!(m::Matching{U}, v::Union{Integer, U}) where {U}
function Base.push!(m::Matching, v)
push!(m.match, v)
if v !== unassigned && m.inv_match !== nothing
if v isa Integer && m.inv_match !== nothing
m.inv_match[v] = length(m.match)
end
end
Expand Down Expand Up @@ -299,8 +303,8 @@ vertices, subject to the constraint that vertices for which `srcfilter` or `dstf
return `false` may not be matched.
"""
function maximal_matching(g::BipartiteGraph, srcfilter = vsrc -> true,
dstfilter = vdst -> true)
matching = Matching(ndsts(g))
dstfilter = vdst -> true, ::Type{U} = Unassigned) where {U}
matching = Matching{U}(ndsts(g))
foreach(Iterators.filter(srcfilter, 𝑠vertices(g))) do vsrc
construct_augmenting_path!(matching, g, vsrc, dstfilter)
end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@ function tear_graph_block_modia!(var_eq_matching, graph, solvable_graph, eqs, va
return nothing
end

function tear_graph_modia(structure::SystemStructure; varfilter = v -> true,
eqfilter = eq -> true)
function tear_graph_modia(structure::SystemStructure, ::Type{U} = Unassigned;
varfilter = v -> true, eqfilter = eq -> true) where {U}
@unpack graph, solvable_graph = structure
var_eq_matching = complete(maximal_matching(graph, eqfilter, varfilter))
var_eq_matching = complete(maximal_matching(graph, eqfilter, varfilter, U))
var_sccs::Vector{Union{Vector{Int}, Int}} = find_var_sccs(graph, var_eq_matching)

for vars in var_sccs
Expand Down
86 changes: 56 additions & 30 deletions src/structural_transformation/partial_state_selection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -146,53 +146,48 @@ function partial_state_selection_graph!(structure::SystemStructure, var_eq_match
var_eq_matching
end

function dummy_derivative_graph!(state::TransformationState, jac = nothing)
function dummy_derivative_graph!(state::TransformationState, jac = nothing; kwargs...)
state.structure.solvable_graph === nothing && find_solvables!(state; kwargs...)
var_eq_matching = complete(pantelides!(state))
complete!(state.structure)
dummy_derivative_graph!(state.structure, var_eq_matching, jac)
end

function dummy_derivative_graph!(structure::SystemStructure, var_eq_matching, jac)
@unpack eq_to_diff, var_to_diff, graph = structure
diff_to_eq = invview(eq_to_diff)
diff_to_var = invview(var_to_diff)
invgraph = invview(graph)

neqs = nsrcs(graph)
eqlevel = zeros(Int, neqs)
function compute_diff_level(diff_to_x)
nxs = length(diff_to_x)
xlevel = zeros(Int, nxs)
maxlevel = 0
for i in 1:neqs
for i in 1:nxs
level = 0
eq = i
while diff_to_eq[eq] !== nothing
eq = diff_to_eq[eq]
x = i
while diff_to_x[x] !== nothing
x = diff_to_x[x]
level += 1
end
maxlevel = max(maxlevel, level)
eqlevel[i] = level
xlevel[i] = level
end
return xlevel, maxlevel
end

nvars = ndsts(graph)
varlevel = zeros(Int, nvars)
for i in 1:nvars
level = 0
var = i
while diff_to_var[var] !== nothing
var = diff_to_var[var]
level += 1
end
maxlevel = max(maxlevel, level)
varlevel[i] = level
end
function dummy_derivative_graph!(structure::SystemStructure, var_eq_matching, jac)
@unpack eq_to_diff, var_to_diff, graph = structure
diff_to_eq = invview(eq_to_diff)
diff_to_var = invview(var_to_diff)
invgraph = invview(graph)

eqlevel, _ = compute_diff_level(diff_to_eq)
varlevel, _ = compute_diff_level(diff_to_var)

var_sccs = find_var_sccs(graph, var_eq_matching)
eqcolor = falses(neqs)
eqcolor = falses(nsrcs(graph))
dummy_derivatives = Int[]
col_order = Int[]
nvars = ndsts(graph)
for vars in var_sccs
eqs = [var_eq_matching[var] for var in vars if var_eq_matching[var] !== unassigned]
isempty(eqs) && continue
maxlevel = maximum(map(x -> eqlevel[x], eqs))
maxlevel = maximum(Base.Fix1(getindex, eqlevel), eqs)
iszero(maxlevel) && continue

rank_matching = Matching(nvars)
Expand Down Expand Up @@ -220,8 +215,10 @@ function dummy_derivative_graph!(structure::SystemStructure, var_eq_matching, ja
else
rank = 0
for var in vars
# We need `invgraph` here because we are matching from
# variables to equations.
pathfound = construct_augmenting_path!(rank_matching, invgraph, var,
eq -> eq in eqs_set, eqcolor)
Base.Fix2(in, eqs_set), eqcolor)
pathfound || continue
push!(dummy_derivatives, var)
rank += 1
Expand All @@ -239,5 +236,34 @@ function dummy_derivative_graph!(structure::SystemStructure, var_eq_matching, ja
end
end

dummy_derivatives
dummy_derivatives_set = BitSet(dummy_derivatives)
# We can eliminate variables that are not a selected state (differential
# variables). Selected states are differentiated variables that are not
# dummy derivatives.
can_eliminate = let var_to_diff = var_to_diff, dummy_derivatives_set = dummy_derivatives_set

v -> begin
dv = var_to_diff[v]
dv === nothing || dv in dummy_derivatives_set
end
end

# We don't want tearing to give us `y_t ~ D(y)`, so we skip equations with
# actually differentiated variables.
isdiffed = let diff_to_var = diff_to_var, dummy_derivatives_set = dummy_derivatives_set
v -> diff_to_var[v] !== nothing && !(v in dummy_derivatives_set)
end
should_consider = let graph = graph, isdiffed = isdiffed
eq -> !any(isdiffed, 𝑠neighbors(graph, eq))
end

var_eq_matching = tear_graph_modia(structure, Union{Unassigned, SelectedState};
varfilter = can_eliminate,
eqfilter = should_consider)
for v in eachindex(var_eq_matching)
can_eliminate(v) && continue
var_eq_matching[v] = SelectedState()
end

return var_eq_matching
end
Loading