diff --git a/Project.toml b/Project.toml index 017e374e3d..66783ae237 100644 --- a/Project.toml +++ b/Project.toml @@ -36,7 +36,6 @@ RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" -SimpleWeightedGraphs = "47aef6b3-ad0c-573a-a1e2-d07658019622" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" @@ -73,7 +72,6 @@ Reexport = "0.2, 1" RuntimeGeneratedFunctions = "0.4.3, 0.5" SciMLBase = "1.58.0" Setfield = "0.7, 0.8, 1" -SimpleWeightedGraphs = "1" SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2" StaticArrays = "0.10, 0.11, 0.12, 1.0" SymbolicUtils = "0.19" diff --git a/src/systems/alias_elimination.jl b/src/systems/alias_elimination.jl index f20e4bf572..0c20c88dfc 100644 --- a/src/systems/alias_elimination.jl +++ b/src/systems/alias_elimination.jl @@ -1,5 +1,4 @@ using SymbolicUtils: Rewriters -using SimpleWeightedGraphs using Graphs.Experimental.Traversals const KEEP = typemin(Int) @@ -365,9 +364,33 @@ function Base.in(i::Int, agk::AliasGraphKeySet) 1 <= i <= length(aliasto) && aliasto[i] !== nothing end +canonicalize(a, b) = a <= b ? (a, b) : (b, a) +struct WeightedGraph{T, W} <: AbstractGraph{Int64} + graph::SimpleGraph{T} + dict::Dict{Tuple{T, T}, W} +end +function WeightedGraph{T, W}(n) where {T, W} + WeightedGraph{T, W}(SimpleGraph{T}(n), Dict{Tuple{T, T}, W}()) +end + +function Graphs.add_edge!(g::WeightedGraph, u, v, w) + r = add_edge!(g.graph, u, v) + r && (g.dict[canonicalize(u, v)] = w) + r +end +Graphs.has_edge(g::WeightedGraph, u, v) = has_edge(g.graph, u, v) +Graphs.ne(g::WeightedGraph) = ne(g.graph) +Graphs.nv(g::WeightedGraph) = nv(g.graph) +get_weight(g::WeightedGraph, u, v) = g.dict[canonicalize(u, v)] +Graphs.is_directed(::Type{<:WeightedGraph}) = false +Graphs.inneighbors(g::WeightedGraph, v) = inneighbors(g.graph, v) +Graphs.outneighbors(g::WeightedGraph, v) = outneighbors(g.graph, v) +Graphs.vertices(g::WeightedGraph) = vertices(g.graph) +Graphs.edges(g::WeightedGraph) = vertices(g.graph) + function equality_diff_graph(ag::AliasGraph, var_to_diff::DiffGraph) g = SimpleDiGraph{Int}(length(var_to_diff)) - eqg = SimpleWeightedGraph{Int, Int}(length(var_to_diff)) + eqg = WeightedGraph{Int, Int}(length(var_to_diff)) zero_vars = Int[] for (v, (c, a)) in ag if iszero(a) @@ -378,7 +401,6 @@ function equality_diff_graph(ag::AliasGraph, var_to_diff::DiffGraph) add_edge!(g, a, v) add_edge!(eqg, v, a, c) - add_edge!(eqg, a, v, c) end transitiveclosure!(g) weighted_transitiveclosure!(eqg) @@ -394,9 +416,14 @@ end function weighted_transitiveclosure!(g) cps = connected_components(g) for cp in cps - for k in cp, i in cp, j in cp - (has_edge(g, i, k) && has_edge(g, k, j)) || continue - add_edge!(g, i, j, get_weight(g, i, k) * get_weight(g, k, j)) + n = length(cp) + for k in cp + for i′ in 1:n, j′ in (i′ + 1):n + i = cp[i′] + j = cp[j′] + (has_edge(g, i, k) && has_edge(g, k, j)) || continue + add_edge!(g, i, j, get_weight(g, i, k) * get_weight(g, k, j)) + end end end return g @@ -670,11 +697,12 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL) (v, w) -> var_to_diff[v] == w || var_to_diff[w] == v end diff_aliases = Vector{Pair{Int, Int}}[] - stem = Int[] + stems = Vector{Int}[] stem_set = BitSet() for (v, dv) in enumerate(var_to_diff) processed[v] && continue (dv === nothing && diff_to_var[v] === nothing) && continue + stem = Int[] r = find_root!(dls, g, v) prev_r = -1 for _ in 1:10_000 # just to make sure that we don't stuck in an infinite loop @@ -714,9 +742,9 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL) push!(stem_set, prev_r) push!(stem, prev_r) push!(diff_aliases, reach₌) - for (_, v) in reach₌ + for (c, v) in reach₌ v == prev_r && continue - add_edge!(eqg, v, prev_r) + add_edge!(eqg, v, prev_r, c) end end @@ -729,9 +757,24 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL) dag[v] = c => a end end - # Obtain transitive closure after completing the alias edges from diff - # edges. - weighted_transitiveclosure!(eqg) + push!(stems, stem) + + # clean up + for v in dls.visited + dls.dists[v] = typemax(Int) + processed[v] = true + end + empty!(dls.visited) + empty!(diff_aliases) + empty!(stem_set) + end + + # Obtain transitive closure after completing the alias edges from diff + # edges. As a performance optimization, we only compute the transitive + # closure once at the very end. + weighted_transitiveclosure!(eqg) + zero_vars_set = BitSet() + for stem in stems # Canonicalize by preferring the lower differentiated variable # If we have the system # ``` @@ -780,7 +823,6 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL) # x := 0 # y := 0 # ``` - zero_vars_set = BitSet() for v in zero_vars for a in Iterators.flatten((v, outneighbors(eqg, v))) while true @@ -803,17 +845,9 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL) dag[v] = 0 end end - - # clean up - for v in dls.visited - dls.dists[v] = typemax(Int) - processed[v] = true - end - empty!(dls.visited) - empty!(diff_aliases) - empty!(stem) - empty!(stem_set) + empty!(zero_vars_set) end + # update `dag` for k in keys(dag) dag[k]