Skip to content

Commit

Permalink
Only compute weighted_transitiveclosure! twice
Browse files Browse the repository at this point in the history
  • Loading branch information
YingboMa committed Oct 18, 2022
1 parent b6de5a1 commit db05a03
Showing 1 changed file with 22 additions and 15 deletions.
37 changes: 22 additions & 15 deletions src/systems/alias_elimination.jl
Original file line number Diff line number Diff line change
Expand Up @@ -697,11 +697,12 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
(v, w) -> var_to_diff[v] == w || var_to_diff[w] == v
end
diff_aliases = Vector{Pair{Int, Int}}[]
stem = Int[]
stems = Vector{Int}[]
stem_set = BitSet()
for (v, dv) in enumerate(var_to_diff)
processed[v] && continue
(dv === nothing && diff_to_var[v] === nothing) && continue
stem = Int[]
r = find_root!(dls, g, v)
prev_r = -1
for _ in 1:10_000 # just to make sure that we don't stuck in an infinite loop
Expand Down Expand Up @@ -756,9 +757,24 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
dag[v] = c => a
end
end
# Obtain transitive closure after completing the alias edges from diff
# edges.
weighted_transitiveclosure!(eqg)
push!(stems, stem)

# clean up
for v in dls.visited
dls.dists[v] = typemax(Int)
processed[v] = true
end
empty!(dls.visited)
empty!(diff_aliases)
empty!(stem_set)
end

# Obtain transitive closure after completing the alias edges from diff
# edges. As a performance optimization, we only compute the transitive
# closure once at the very end.
weighted_transitiveclosure!(eqg)
zero_vars_set = BitSet()
for stem in stems
# Canonicalize by preferring the lower differentiated variable
# If we have the system
# ```
Expand Down Expand Up @@ -807,7 +823,6 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
# x := 0
# y := 0
# ```
zero_vars_set = BitSet()
for v in zero_vars
for a in Iterators.flatten((v, outneighbors(eqg, v)))
while true
Expand All @@ -830,17 +845,9 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
dag[v] = 0
end
end

# clean up
for v in dls.visited
dls.dists[v] = typemax(Int)
processed[v] = true
end
empty!(dls.visited)
empty!(diff_aliases)
empty!(stem)
empty!(stem_set)
empty!(zero_vars_set)
end

# update `dag`
for k in keys(dag)
dag[k]
Expand Down

0 comments on commit db05a03

Please sign in to comment.