Skip to content

Commit

Permalink
Merge db05a03 into 8a03d0e
Browse files Browse the repository at this point in the history
  • Loading branch information
YingboMa committed Oct 18, 2022
2 parents 8a03d0e + db05a03 commit 8903337
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 25 deletions.
2 changes: 0 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
SimpleWeightedGraphs = "47aef6b3-ad0c-573a-a1e2-d07658019622"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Expand Down Expand Up @@ -73,7 +72,6 @@ Reexport = "0.2, 1"
RuntimeGeneratedFunctions = "0.4.3, 0.5"
SciMLBase = "1.58.0"
Setfield = "0.7, 0.8, 1"
SimpleWeightedGraphs = "1"
SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2"
StaticArrays = "0.10, 0.11, 0.12, 1.0"
SymbolicUtils = "0.19"
Expand Down
80 changes: 57 additions & 23 deletions src/systems/alias_elimination.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using SymbolicUtils: Rewriters
using SimpleWeightedGraphs
using Graphs.Experimental.Traversals

const KEEP = typemin(Int)
Expand Down Expand Up @@ -365,9 +364,33 @@ function Base.in(i::Int, agk::AliasGraphKeySet)
1 <= i <= length(aliasto) && aliasto[i] !== nothing
end

canonicalize(a, b) = a <= b ? (a, b) : (b, a)
struct WeightedGraph{T, W} <: AbstractGraph{Int64}
graph::SimpleGraph{T}
dict::Dict{Tuple{T, T}, W}
end
function WeightedGraph{T, W}(n) where {T, W}
WeightedGraph{T, W}(SimpleGraph{T}(n), Dict{Tuple{T, T}, W}())
end

function Graphs.add_edge!(g::WeightedGraph, u, v, w)
r = add_edge!(g.graph, u, v)
r && (g.dict[canonicalize(u, v)] = w)
r
end
Graphs.has_edge(g::WeightedGraph, u, v) = has_edge(g.graph, u, v)
Graphs.ne(g::WeightedGraph) = ne(g.graph)
Graphs.nv(g::WeightedGraph) = nv(g.graph)
get_weight(g::WeightedGraph, u, v) = g.dict[canonicalize(u, v)]
Graphs.is_directed(::Type{<:WeightedGraph}) = false
Graphs.inneighbors(g::WeightedGraph, v) = inneighbors(g.graph, v)
Graphs.outneighbors(g::WeightedGraph, v) = outneighbors(g.graph, v)
Graphs.vertices(g::WeightedGraph) = vertices(g.graph)
Graphs.edges(g::WeightedGraph) = vertices(g.graph)

function equality_diff_graph(ag::AliasGraph, var_to_diff::DiffGraph)
g = SimpleDiGraph{Int}(length(var_to_diff))
eqg = SimpleWeightedGraph{Int, Int}(length(var_to_diff))
eqg = WeightedGraph{Int, Int}(length(var_to_diff))
zero_vars = Int[]
for (v, (c, a)) in ag
if iszero(a)
Expand All @@ -378,7 +401,6 @@ function equality_diff_graph(ag::AliasGraph, var_to_diff::DiffGraph)
add_edge!(g, a, v)

add_edge!(eqg, v, a, c)
add_edge!(eqg, a, v, c)
end
transitiveclosure!(g)
weighted_transitiveclosure!(eqg)
Expand All @@ -394,9 +416,14 @@ end
function weighted_transitiveclosure!(g)
cps = connected_components(g)
for cp in cps
for k in cp, i in cp, j in cp
(has_edge(g, i, k) && has_edge(g, k, j)) || continue
add_edge!(g, i, j, get_weight(g, i, k) * get_weight(g, k, j))
n = length(cp)
for k in cp
for i′ in 1:n, j′ in (i′ + 1):n
i = cp[i′]
j = cp[j′]
(has_edge(g, i, k) && has_edge(g, k, j)) || continue
add_edge!(g, i, j, get_weight(g, i, k) * get_weight(g, k, j))
end
end
end
return g
Expand Down Expand Up @@ -670,11 +697,12 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
(v, w) -> var_to_diff[v] == w || var_to_diff[w] == v
end
diff_aliases = Vector{Pair{Int, Int}}[]
stem = Int[]
stems = Vector{Int}[]
stem_set = BitSet()
for (v, dv) in enumerate(var_to_diff)
processed[v] && continue
(dv === nothing && diff_to_var[v] === nothing) && continue
stem = Int[]
r = find_root!(dls, g, v)
prev_r = -1
for _ in 1:10_000 # just to make sure that we don't stuck in an infinite loop
Expand Down Expand Up @@ -714,9 +742,9 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
push!(stem_set, prev_r)
push!(stem, prev_r)
push!(diff_aliases, reach₌)
for (_, v) in reach₌
for (c, v) in reach₌
v == prev_r && continue
add_edge!(eqg, v, prev_r)
add_edge!(eqg, v, prev_r, c)
end
end

Expand All @@ -729,9 +757,24 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
dag[v] = c => a
end
end
# Obtain transitive closure after completing the alias edges from diff
# edges.
weighted_transitiveclosure!(eqg)
push!(stems, stem)

# clean up
for v in dls.visited
dls.dists[v] = typemax(Int)
processed[v] = true
end
empty!(dls.visited)
empty!(diff_aliases)
empty!(stem_set)
end

# Obtain transitive closure after completing the alias edges from diff
# edges. As a performance optimization, we only compute the transitive
# closure once at the very end.
weighted_transitiveclosure!(eqg)
zero_vars_set = BitSet()
for stem in stems
# Canonicalize by preferring the lower differentiated variable
# If we have the system
# ```
Expand Down Expand Up @@ -780,7 +823,6 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
# x := 0
# y := 0
# ```
zero_vars_set = BitSet()
for v in zero_vars
for a in Iterators.flatten((v, outneighbors(eqg, v)))
while true
Expand All @@ -803,17 +845,9 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
dag[v] = 0
end
end

# clean up
for v in dls.visited
dls.dists[v] = typemax(Int)
processed[v] = true
end
empty!(dls.visited)
empty!(diff_aliases)
empty!(stem)
empty!(stem_set)
empty!(zero_vars_set)
end

# update `dag`
for k in keys(dag)
dag[k]
Expand Down

0 comments on commit 8903337

Please sign in to comment.