Skip to content

Commit

Permalink
Merge f934c96 into 3ee1f87
Browse files Browse the repository at this point in the history
  • Loading branch information
YingboMa committed Oct 19, 2022
2 parents 3ee1f87 + f934c96 commit c6c61ee
Show file tree
Hide file tree
Showing 6 changed files with 103 additions and 36 deletions.
2 changes: 0 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
SimpleWeightedGraphs = "47aef6b3-ad0c-573a-a1e2-d07658019622"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Expand Down Expand Up @@ -73,7 +72,6 @@ Reexport = "0.2, 1"
RuntimeGeneratedFunctions = "0.4.3, 0.5"
SciMLBase = "1.58.0"
Setfield = "0.7, 0.8, 1"
SimpleWeightedGraphs = "1"
SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2"
StaticArrays = "0.10, 0.11, 0.12, 1.0"
SymbolicUtils = "0.19"
Expand Down
10 changes: 5 additions & 5 deletions src/inputoutput.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,17 +119,17 @@ function same_or_inner_namespace(u, var)
nv = get_namespace(var)
nu == nv || # namespaces are the same
startswith(nv, nu) || # or nv starts with nu, i.e., nv is an inner namepsace to nu
occursin('', string(Symbolics.getname(var))) &&
!occursin('', string(Symbolics.getname(u))) # or u is top level but var is internal
occursin('', string(getname(var))) &&
!occursin('', string(getname(u))) # or u is top level but var is internal
end

function inner_namespace(u, var)
nu = get_namespace(u)
nv = get_namespace(var)
nu == nv && return false
startswith(nv, nu) || # or nv starts with nu, i.e., nv is an inner namepsace to nu
occursin('', string(Symbolics.getname(var))) &&
!occursin('', string(Symbolics.getname(u))) # or u is top level but var is internal
occursin('', string(getname(var))) &&
!occursin('', string(getname(u))) # or u is top level but var is internal
end

"""
Expand All @@ -138,7 +138,7 @@ end
Return the namespace of a variable as a string. If the variable is not namespaced, the string is empty.
"""
function get_namespace(x)
sname = string(Symbolics.getname(x))
sname = string(getname(x))
parts = split(sname, '')
if length(parts) == 1
return ""
Expand Down
3 changes: 2 additions & 1 deletion src/structural_transformation/StructuralTransformations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ using ModelingToolkit: ODESystem, AbstractSystem, var_from_nested_derivative, Di
get_postprocess_fbody, vars!,
IncrementalCycleTracker, add_edge_checked!, topological_sort,
invalidate_cache!, Substitutions, get_or_construct_tearing_state,
AliasGraph, filter_kwargs, lower_varname, setio, SparseMatrixCLIL
AliasGraph, filter_kwargs, lower_varname, setio, SparseMatrixCLIL,
fast_substitute

using ModelingToolkit.BipartiteGraphs
import .BipartiteGraphs: invview, complete
Expand Down
8 changes: 4 additions & 4 deletions src/structural_transformation/symbolics_tearing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching, ag = nothing;
idx_buffer = Int[]
sub_callback! = let eqs = neweqs, fullvars = fullvars
(ieq, s) -> begin
neweq = substitute(eqs[ieq], fullvars[s[1]] => fullvars[s[2]])
neweq = fast_substitute(eqs[ieq], fullvars[s[1]] => fullvars[s[2]])
eqs[ieq] = neweq
end
end
Expand Down Expand Up @@ -282,7 +282,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching, ag = nothing;
end
for eq in 𝑑neighbors(graph, dv)
dummy_sub[dd] = v_t
neweqs[eq] = substitute(neweqs[eq], dd => v_t)
neweqs[eq] = fast_substitute(neweqs[eq], dd => v_t)
end
fullvars[dv] = v_t
# If we have:
Expand All @@ -295,7 +295,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching, ag = nothing;
while (ddx = var_to_diff[dx]) !== nothing
dx_t = D(x_t)
for eq in 𝑑neighbors(graph, ddx)
neweqs[eq] = substitute(neweqs[eq], fullvars[ddx] => dx_t)
neweqs[eq] = fast_substitute(neweqs[eq], fullvars[ddx] => dx_t)
end
fullvars[ddx] = dx_t
dx = ddx
Expand Down Expand Up @@ -655,7 +655,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching, ag = nothing;
obs_sub[eq.lhs] = eq.rhs
end
# TODO: compute the dependency correctly so that we don't have to do this
obs = substitute.([oldobs; subeqs], (obs_sub,))
obs = fast_substitute([oldobs; subeqs], obs_sub)
@set! sys.observed = obs
@set! state.sys = sys
@set! sys.tearing_state = state
Expand Down
83 changes: 59 additions & 24 deletions src/systems/alias_elimination.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using SymbolicUtils: Rewriters
using SimpleWeightedGraphs
using Graphs.Experimental.Traversals

const KEEP = typemin(Int)
Expand Down Expand Up @@ -153,7 +152,8 @@ function alias_elimination!(state::TearingState; kwargs...)
end
end
for ieq in eqs_to_update
eqs[ieq] = substitute(eqs[ieq], subs)
eq = eqs[ieq]
eqs[ieq] = fast_substitute(eq, subs)
end

for old_ieq in to_expand
Expand Down Expand Up @@ -365,9 +365,33 @@ function Base.in(i::Int, agk::AliasGraphKeySet)
1 <= i <= length(aliasto) && aliasto[i] !== nothing
end

canonicalize(a, b) = a <= b ? (a, b) : (b, a)
struct WeightedGraph{T, W} <: AbstractGraph{T}
graph::SimpleGraph{T}
dict::Dict{Tuple{T, T}, W}
end
function WeightedGraph{T, W}(n) where {T, W}
WeightedGraph{T, W}(SimpleGraph{T}(n), Dict{Tuple{T, T}, W}())
end

function Graphs.add_edge!(g::WeightedGraph, u, v, w)
r = add_edge!(g.graph, u, v)
r && (g.dict[canonicalize(u, v)] = w)
r
end
Graphs.has_edge(g::WeightedGraph, u, v) = has_edge(g.graph, u, v)
Graphs.ne(g::WeightedGraph) = ne(g.graph)
Graphs.nv(g::WeightedGraph) = nv(g.graph)
get_weight(g::WeightedGraph, u, v) = g.dict[canonicalize(u, v)]
Graphs.is_directed(::Type{<:WeightedGraph}) = false
Graphs.inneighbors(g::WeightedGraph, v) = inneighbors(g.graph, v)
Graphs.outneighbors(g::WeightedGraph, v) = outneighbors(g.graph, v)
Graphs.vertices(g::WeightedGraph) = vertices(g.graph)
Graphs.edges(g::WeightedGraph) = vertices(g.graph)

function equality_diff_graph(ag::AliasGraph, var_to_diff::DiffGraph)
g = SimpleDiGraph{Int}(length(var_to_diff))
eqg = SimpleWeightedGraph{Int, Int}(length(var_to_diff))
eqg = WeightedGraph{Int, Int}(length(var_to_diff))
zero_vars = Int[]
for (v, (c, a)) in ag
if iszero(a)
Expand All @@ -378,7 +402,6 @@ function equality_diff_graph(ag::AliasGraph, var_to_diff::DiffGraph)
add_edge!(g, a, v)

add_edge!(eqg, v, a, c)
add_edge!(eqg, a, v, c)
end
transitiveclosure!(g)
weighted_transitiveclosure!(eqg)
Expand All @@ -394,9 +417,14 @@ end
function weighted_transitiveclosure!(g)
cps = connected_components(g)
for cp in cps
for k in cp, i in cp, j in cp
(has_edge(g, i, k) && has_edge(g, k, j)) || continue
add_edge!(g, i, j, get_weight(g, i, k) * get_weight(g, k, j))
n = length(cp)
for k in cp
for i′ in 1:n, j′ in (i′ + 1):n
i = cp[i′]
j = cp[j′]
(has_edge(g, i, k) && has_edge(g, k, j)) || continue
add_edge!(g, i, j, get_weight(g, i, k) * get_weight(g, k, j))
end
end
end
return g
Expand Down Expand Up @@ -670,11 +698,12 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
(v, w) -> var_to_diff[v] == w || var_to_diff[w] == v
end
diff_aliases = Vector{Pair{Int, Int}}[]
stem = Int[]
stems = Vector{Int}[]
stem_set = BitSet()
for (v, dv) in enumerate(var_to_diff)
processed[v] && continue
(dv === nothing && diff_to_var[v] === nothing) && continue
stem = Int[]
r = find_root!(dls, g, v)
prev_r = -1
for _ in 1:10_000 # just to make sure that we don't stuck in an infinite loop
Expand Down Expand Up @@ -714,9 +743,9 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
push!(stem_set, prev_r)
push!(stem, prev_r)
push!(diff_aliases, reach₌)
for (_, v) in reach₌
for (c, v) in reach₌
v == prev_r && continue
add_edge!(eqg, v, prev_r)
add_edge!(eqg, v, prev_r, c)
end
end

Expand All @@ -729,9 +758,24 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
dag[v] = c => a
end
end
# Obtain transitive closure after completing the alias edges from diff
# edges.
weighted_transitiveclosure!(eqg)
push!(stems, stem)

# clean up
for v in dls.visited
dls.dists[v] = typemax(Int)
processed[v] = true
end
empty!(dls.visited)
empty!(diff_aliases)
empty!(stem_set)
end

# Obtain transitive closure after completing the alias edges from diff
# edges. As a performance optimization, we only compute the transitive
# closure once at the very end.
weighted_transitiveclosure!(eqg)
zero_vars_set = BitSet()
for stem in stems
# Canonicalize by preferring the lower differentiated variable
# If we have the system
# ```
Expand Down Expand Up @@ -780,7 +824,6 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
# x := 0
# y := 0
# ```
zero_vars_set = BitSet()
for v in zero_vars
for a in Iterators.flatten((v, outneighbors(eqg, v)))
while true
Expand All @@ -803,17 +846,9 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
dag[v] = 0
end
end

# clean up
for v in dls.visited
dls.dists[v] = typemax(Int)
processed[v] = true
end
empty!(dls.visited)
empty!(diff_aliases)
empty!(stem)
empty!(stem_set)
empty!(zero_vars_set)
end

# update `dag`
for k in keys(dag)
dag[k]
Expand Down
33 changes: 33 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -741,3 +741,36 @@ function jacobian_wrt_vars(pf::F, p, input_idxs, chunk::C) where {F, C}
cfg = ForwardDiff.JacobianConfig(p_closure, p_small, chunk, tag)
ForwardDiff.jacobian(p_closure, p_small, cfg, Val(false))
end

# Symbolics needs to call unwrap on the substitution rules, but most of the time
# we don't want to do that in MTK.
function fast_substitute(eq::Equation, subs)
fast_substitute(eq.lhs, subs) ~ fast_substitute(eq.rhs, subs)
end
function fast_substitute(eq::Equation, subs::Pair)
fast_substitute(eq.lhs, subs) ~ fast_substitute(eq.rhs, subs)
end
fast_substitute(eqs::AbstractArray{Equation}, subs) = fast_substitute.(eqs, (subs,))
fast_substitute(a, b) = substitute(a, b)
function fast_substitute(expr, pair::Pair)
a, b = pair
isequal(expr, a) && return b

istree(expr) || return expr
op = fast_substitute(operation(expr), pair)
canfold = Ref(!(op isa Symbolic))
args = let canfold = canfold
map(SymbolicUtils.unsorted_arguments(expr)) do x
x′ = fast_substitute(x, pair)
canfold[] = canfold[] && !(x′ isa Symbolic)
x′
end
end
canfold[] && return op(args...)

similarterm(expr,
op,
args,
symtype(expr);
metadata = metadata(expr))
end

0 comments on commit c6c61ee

Please sign in to comment.