From b6de5a124b7638a3c6b4ecb15a8f65925e05d117 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Tue, 18 Oct 2022 18:55:59 -0400 Subject: [PATCH 1/6] Use faster data structure for the transitive closure calculation Before: ```julia julia> @time alias_elimination(sysEx); 28.781313 seconds (51.53 M allocations: 3.998 GiB, 2.81% gc time) ``` After: ```julia julia> @time alias_elimination(sysEx); 18.543368 seconds (54.64 M allocations: 4.446 GiB, 4.13% gc time) ``` --- Project.toml | 2 -- src/systems/alias_elimination.jl | 43 ++++++++++++++++++++++++++------ 2 files changed, 35 insertions(+), 10 deletions(-) diff --git a/Project.toml b/Project.toml index 017e374e3d..66783ae237 100644 --- a/Project.toml +++ b/Project.toml @@ -36,7 +36,6 @@ RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" -SimpleWeightedGraphs = "47aef6b3-ad0c-573a-a1e2-d07658019622" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" @@ -73,7 +72,6 @@ Reexport = "0.2, 1" RuntimeGeneratedFunctions = "0.4.3, 0.5" SciMLBase = "1.58.0" Setfield = "0.7, 0.8, 1" -SimpleWeightedGraphs = "1" SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2" StaticArrays = "0.10, 0.11, 0.12, 1.0" SymbolicUtils = "0.19" diff --git a/src/systems/alias_elimination.jl b/src/systems/alias_elimination.jl index f20e4bf572..90b3dc3837 100644 --- a/src/systems/alias_elimination.jl +++ b/src/systems/alias_elimination.jl @@ -1,5 +1,4 @@ using SymbolicUtils: Rewriters -using SimpleWeightedGraphs using Graphs.Experimental.Traversals const KEEP = typemin(Int) @@ -365,9 +364,33 @@ function Base.in(i::Int, agk::AliasGraphKeySet) 1 <= i <= length(aliasto) && aliasto[i] !== nothing end +canonicalize(a, b) = a <= b ? (a, b) : (b, a) +struct WeightedGraph{T, W} <: AbstractGraph{Int64} + graph::SimpleGraph{T} + dict::Dict{Tuple{T, T}, W} +end +function WeightedGraph{T, W}(n) where {T, W} + WeightedGraph{T, W}(SimpleGraph{T}(n), Dict{Tuple{T, T}, W}()) +end + +function Graphs.add_edge!(g::WeightedGraph, u, v, w) + r = add_edge!(g.graph, u, v) + r && (g.dict[canonicalize(u, v)] = w) + r +end +Graphs.has_edge(g::WeightedGraph, u, v) = has_edge(g.graph, u, v) +Graphs.ne(g::WeightedGraph) = ne(g.graph) +Graphs.nv(g::WeightedGraph) = nv(g.graph) +get_weight(g::WeightedGraph, u, v) = g.dict[canonicalize(u, v)] +Graphs.is_directed(::Type{<:WeightedGraph}) = false +Graphs.inneighbors(g::WeightedGraph, v) = inneighbors(g.graph, v) +Graphs.outneighbors(g::WeightedGraph, v) = outneighbors(g.graph, v) +Graphs.vertices(g::WeightedGraph) = vertices(g.graph) +Graphs.edges(g::WeightedGraph) = vertices(g.graph) + function equality_diff_graph(ag::AliasGraph, var_to_diff::DiffGraph) g = SimpleDiGraph{Int}(length(var_to_diff)) - eqg = SimpleWeightedGraph{Int, Int}(length(var_to_diff)) + eqg = WeightedGraph{Int, Int}(length(var_to_diff)) zero_vars = Int[] for (v, (c, a)) in ag if iszero(a) @@ -378,7 +401,6 @@ function equality_diff_graph(ag::AliasGraph, var_to_diff::DiffGraph) add_edge!(g, a, v) add_edge!(eqg, v, a, c) - add_edge!(eqg, a, v, c) end transitiveclosure!(g) weighted_transitiveclosure!(eqg) @@ -394,9 +416,14 @@ end function weighted_transitiveclosure!(g) cps = connected_components(g) for cp in cps - for k in cp, i in cp, j in cp - (has_edge(g, i, k) && has_edge(g, k, j)) || continue - add_edge!(g, i, j, get_weight(g, i, k) * get_weight(g, k, j)) + n = length(cp) + for k in cp + for i′ in 1:n, j′ in (i′ + 1):n + i = cp[i′] + j = cp[j′] + (has_edge(g, i, k) && has_edge(g, k, j)) || continue + add_edge!(g, i, j, get_weight(g, i, k) * get_weight(g, k, j)) + end end end return g @@ -714,9 +741,9 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL) push!(stem_set, prev_r) push!(stem, prev_r) push!(diff_aliases, reach₌) - for (_, v) in reach₌ + for (c, v) in reach₌ v == prev_r && continue - add_edge!(eqg, v, prev_r) + add_edge!(eqg, v, prev_r, c) end end From db05a0348f8fd68606684844525f0c96af785431 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Tue, 18 Oct 2022 19:13:08 -0400 Subject: [PATCH 2/6] Only compute `weighted_transitiveclosure!` twice --- src/systems/alias_elimination.jl | 37 +++++++++++++++++++------------- 1 file changed, 22 insertions(+), 15 deletions(-) diff --git a/src/systems/alias_elimination.jl b/src/systems/alias_elimination.jl index 90b3dc3837..0c20c88dfc 100644 --- a/src/systems/alias_elimination.jl +++ b/src/systems/alias_elimination.jl @@ -697,11 +697,12 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL) (v, w) -> var_to_diff[v] == w || var_to_diff[w] == v end diff_aliases = Vector{Pair{Int, Int}}[] - stem = Int[] + stems = Vector{Int}[] stem_set = BitSet() for (v, dv) in enumerate(var_to_diff) processed[v] && continue (dv === nothing && diff_to_var[v] === nothing) && continue + stem = Int[] r = find_root!(dls, g, v) prev_r = -1 for _ in 1:10_000 # just to make sure that we don't stuck in an infinite loop @@ -756,9 +757,24 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL) dag[v] = c => a end end - # Obtain transitive closure after completing the alias edges from diff - # edges. - weighted_transitiveclosure!(eqg) + push!(stems, stem) + + # clean up + for v in dls.visited + dls.dists[v] = typemax(Int) + processed[v] = true + end + empty!(dls.visited) + empty!(diff_aliases) + empty!(stem_set) + end + + # Obtain transitive closure after completing the alias edges from diff + # edges. As a performance optimization, we only compute the transitive + # closure once at the very end. + weighted_transitiveclosure!(eqg) + zero_vars_set = BitSet() + for stem in stems # Canonicalize by preferring the lower differentiated variable # If we have the system # ``` @@ -807,7 +823,6 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL) # x := 0 # y := 0 # ``` - zero_vars_set = BitSet() for v in zero_vars for a in Iterators.flatten((v, outneighbors(eqg, v))) while true @@ -830,17 +845,9 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL) dag[v] = 0 end end - - # clean up - for v in dls.visited - dls.dists[v] = typemax(Int) - processed[v] = true - end - empty!(dls.visited) - empty!(diff_aliases) - empty!(stem) - empty!(stem_set) + empty!(zero_vars_set) end + # update `dag` for k in keys(dag) dag[k] From 022008b4f051d22e4621d9be48868517557621cc Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Tue, 18 Oct 2022 19:27:40 -0400 Subject: [PATCH 3/6] Use a lower level API to get more speed up ```julia julia> @time alias_elimination(sysEx); 1.130002 seconds (6.42 M allocations: 346.369 MiB, 12.44% gc time) ``` --- src/systems/alias_elimination.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/systems/alias_elimination.jl b/src/systems/alias_elimination.jl index 0c20c88dfc..4648f42c10 100644 --- a/src/systems/alias_elimination.jl +++ b/src/systems/alias_elimination.jl @@ -151,8 +151,10 @@ function alias_elimination!(state::TearingState; kwargs...) k === nothing && break end end + subfun = Base.Fix2(substitute, subs) for ieq in eqs_to_update - eqs[ieq] = substitute(eqs[ieq], subs) + eq = eqs[ieq] + eqs[ieq] = subfun(eq.lhs) ~ subfun(eq.rhs) end for old_ieq in to_expand From c049454c98be8c52abedf8a7ac91599d28194bbf Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Tue, 18 Oct 2022 20:01:53 -0400 Subject: [PATCH 4/6] Add `fast_substitute` Before: ```julia julia> @time sysRed = tearing(sysEx); 8.903631 seconds (42.38 M allocations: 2.968 GiB, 8.06% gc time) ``` After: ```julia julia> @time tearing(sysEx); 1.733097 seconds (10.90 M allocations: 1.059 GiB, 19.44% gc time) ``` --- .../StructuralTransformations.jl | 3 +- .../symbolics_tearing.jl | 8 ++--- src/systems/alias_elimination.jl | 3 +- src/utils.jl | 33 +++++++++++++++++++ 4 files changed, 40 insertions(+), 7 deletions(-) diff --git a/src/structural_transformation/StructuralTransformations.jl b/src/structural_transformation/StructuralTransformations.jl index ec263ca92c..eb3a35090b 100644 --- a/src/structural_transformation/StructuralTransformations.jl +++ b/src/structural_transformation/StructuralTransformations.jl @@ -22,7 +22,8 @@ using ModelingToolkit: ODESystem, AbstractSystem, var_from_nested_derivative, Di get_postprocess_fbody, vars!, IncrementalCycleTracker, add_edge_checked!, topological_sort, invalidate_cache!, Substitutions, get_or_construct_tearing_state, - AliasGraph, filter_kwargs, lower_varname, setio, SparseMatrixCLIL + AliasGraph, filter_kwargs, lower_varname, setio, SparseMatrixCLIL, + fast_substitute using ModelingToolkit.BipartiteGraphs import .BipartiteGraphs: invview, complete diff --git a/src/structural_transformation/symbolics_tearing.jl b/src/structural_transformation/symbolics_tearing.jl index 9e96ccdd71..9b49cb0db7 100644 --- a/src/structural_transformation/symbolics_tearing.jl +++ b/src/structural_transformation/symbolics_tearing.jl @@ -227,7 +227,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching, ag = nothing; idx_buffer = Int[] sub_callback! = let eqs = neweqs, fullvars = fullvars (ieq, s) -> begin - neweq = substitute(eqs[ieq], fullvars[s[1]] => fullvars[s[2]]) + neweq = fast_substitute(eqs[ieq], fullvars[s[1]] => fullvars[s[2]]) eqs[ieq] = neweq end end @@ -282,7 +282,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching, ag = nothing; end for eq in 𝑑neighbors(graph, dv) dummy_sub[dd] = v_t - neweqs[eq] = substitute(neweqs[eq], dd => v_t) + neweqs[eq] = fast_substitute(neweqs[eq], dd => v_t) end fullvars[dv] = v_t # If we have: @@ -295,7 +295,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching, ag = nothing; while (ddx = var_to_diff[dx]) !== nothing dx_t = D(x_t) for eq in 𝑑neighbors(graph, ddx) - neweqs[eq] = substitute(neweqs[eq], fullvars[ddx] => dx_t) + neweqs[eq] = fast_substitute(neweqs[eq], fullvars[ddx] => dx_t) end fullvars[ddx] = dx_t dx = ddx @@ -655,7 +655,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching, ag = nothing; obs_sub[eq.lhs] = eq.rhs end # TODO: compute the dependency correctly so that we don't have to do this - obs = substitute.([oldobs; subeqs], (obs_sub,)) + obs = fast_substitute([oldobs; subeqs], obs_sub) @set! sys.observed = obs @set! state.sys = sys @set! sys.tearing_state = state diff --git a/src/systems/alias_elimination.jl b/src/systems/alias_elimination.jl index 4648f42c10..59a3945607 100644 --- a/src/systems/alias_elimination.jl +++ b/src/systems/alias_elimination.jl @@ -151,10 +151,9 @@ function alias_elimination!(state::TearingState; kwargs...) k === nothing && break end end - subfun = Base.Fix2(substitute, subs) for ieq in eqs_to_update eq = eqs[ieq] - eqs[ieq] = subfun(eq.lhs) ~ subfun(eq.rhs) + eqs[ieq] = fast_substitute(eq, subs) end for old_ieq in to_expand diff --git a/src/utils.jl b/src/utils.jl index 70c6eecfba..19580a0189 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -741,3 +741,36 @@ function jacobian_wrt_vars(pf::F, p, input_idxs, chunk::C) where {F, C} cfg = ForwardDiff.JacobianConfig(p_closure, p_small, chunk, tag) ForwardDiff.jacobian(p_closure, p_small, cfg, Val(false)) end + +# Symbolics needs to call unwrap on the substitution rules, but most of the time +# we don't want to do that in MTK. +function fast_substitute(eq::Equation, subs) + fast_substitute(eq.lhs, subs) ~ fast_substitute(eq.rhs, subs) +end +function fast_substitute(eq::Equation, subs::Pair) + fast_substitute(eq.lhs, subs) ~ fast_substitute(eq.rhs, subs) +end +fast_substitute(eqs::AbstractArray{Equation}, subs) = fast_substitute.(eqs, (subs,)) +fast_substitute(a, b) = substitute(a, b) +function fast_substitute(expr, pair::Pair) + a, b = pair + isequal(expr, a) && return b + + istree(expr) || return expr + op = fast_substitute(operation(expr), pair) + canfold = Ref(!(op isa Symbolic)) + args = let canfold = canfold + map(SymbolicUtils.unsorted_arguments(expr)) do x + x′ = fast_substitute(x, pair) + canfold[] = canfold[] && !(x′ isa Symbolic) + x′ + end + end + canfold[] && return op(args...) + + similarterm(expr, + op, + args, + symtype(expr); + metadata = metadata(expr)) +end From d5983bb10149752e40f3a6cd0810d7f29c6e9047 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Tue, 18 Oct 2022 20:02:17 -0400 Subject: [PATCH 5/6] Remove unnecessary `Symbolics.` --- src/inputoutput.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/inputoutput.jl b/src/inputoutput.jl index 8fd1ed3022..859dc4a2b9 100644 --- a/src/inputoutput.jl +++ b/src/inputoutput.jl @@ -119,8 +119,8 @@ function same_or_inner_namespace(u, var) nv = get_namespace(var) nu == nv || # namespaces are the same startswith(nv, nu) || # or nv starts with nu, i.e., nv is an inner namepsace to nu - occursin('₊', string(Symbolics.getname(var))) && - !occursin('₊', string(Symbolics.getname(u))) # or u is top level but var is internal + occursin('₊', string(getname(var))) && + !occursin('₊', string(getname(u))) # or u is top level but var is internal end function inner_namespace(u, var) @@ -128,8 +128,8 @@ function inner_namespace(u, var) nv = get_namespace(var) nu == nv && return false startswith(nv, nu) || # or nv starts with nu, i.e., nv is an inner namepsace to nu - occursin('₊', string(Symbolics.getname(var))) && - !occursin('₊', string(Symbolics.getname(u))) # or u is top level but var is internal + occursin('₊', string(getname(var))) && + !occursin('₊', string(getname(u))) # or u is top level but var is internal end """ @@ -138,7 +138,7 @@ end Return the namespace of a variable as a string. If the variable is not namespaced, the string is empty. """ function get_namespace(x) - sname = string(Symbolics.getname(x)) + sname = string(getname(x)) parts = split(sname, '₊') if length(parts) == 1 return "" From f934c96ce65fbf05f07a9023b3efa1c56b0556e2 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Tue, 18 Oct 2022 20:17:40 -0400 Subject: [PATCH 6/6] Fix typo --- src/systems/alias_elimination.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/systems/alias_elimination.jl b/src/systems/alias_elimination.jl index 59a3945607..3fb3fda6a1 100644 --- a/src/systems/alias_elimination.jl +++ b/src/systems/alias_elimination.jl @@ -366,7 +366,7 @@ function Base.in(i::Int, agk::AliasGraphKeySet) end canonicalize(a, b) = a <= b ? (a, b) : (b, a) -struct WeightedGraph{T, W} <: AbstractGraph{Int64} +struct WeightedGraph{T, W} <: AbstractGraph{T} graph::SimpleGraph{T} dict::Dict{Tuple{T, T}, W} end