Skip to content

Commit

Permalink
Merge pull request #1889 from SciML/myb/tearing_opt
Browse files Browse the repository at this point in the history
Make tearing only allocate O(N) memory and fix `count_nonzeros` perf problem
  • Loading branch information
YingboMa authored Oct 19, 2022
2 parents e15c650 + 9d1e3cb commit 9a44828
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 4 deletions.
14 changes: 11 additions & 3 deletions src/structural_transformation/bipartite_tearing/modia_tearing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,8 @@ function tearEquations!(ict::IncrementalCycleTracker, Gsolvable, es::Vector{Int}
return ict
end

function tear_graph_block_modia!(var_eq_matching, graph, solvable_graph, eqs, vars,
function tear_graph_block_modia!(var_eq_matching, ict, solvable_graph, eqs, vars,
isder::F) where {F}
ict = IncrementalCycleTracker(DiCMOBiGraph{true}(graph); dir = :in)
tearEquations!(ict, solvable_graph.fadjlist, eqs, vars, isder)
for var in vars
var_eq_matching[var] = ict.graph.matching[var]
Expand All @@ -76,6 +75,8 @@ function tear_graph_modia(structure::SystemStructure, isder::F = nothing,
@unpack graph, solvable_graph = structure
var_eq_matching = complete(maximal_matching(graph, eqfilter, varfilter, U))
var_sccs::Vector{Union{Vector{Int}, Int}} = find_var_sccs(graph, var_eq_matching)
vargraph = DiCMOBiGraph{true}(graph)
ict = IncrementalCycleTracker(vargraph; dir = :in)

ieqs = Int[]
filtered_vars = BitSet()
Expand All @@ -89,8 +90,15 @@ function tear_graph_modia(structure::SystemStructure, isder::F = nothing,
end
var_eq_matching[var] = unassigned
end
tear_graph_block_modia!(var_eq_matching, graph, solvable_graph, ieqs, filtered_vars,
tear_graph_block_modia!(var_eq_matching, ict, solvable_graph, ieqs,
filtered_vars,
isder)

# clear cache
vargraph.ne = 0
for var in vars
vargraph.matching[var] = unassigned
end
empty!(ieqs)
empty!(filtered_vars)
end
Expand Down
2 changes: 1 addition & 1 deletion src/systems/alias_elimination.jl
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,7 @@ count_nonzeros(a::AbstractArray) = count(!iszero, a)

# N.B.: Ordinarily sparse vectors allow zero stored elements.
# Here we have a guarantee that they won't, so we can make this identification
count_nonzeros(a::SparseVector) = nnz(a)
count_nonzeros(a::CLILVector) = nnz(a)

# Linear variables are highest order differentiated variables that only appear
# in linear equations with only linear variables. Also, if a variable's any
Expand Down
1 change: 1 addition & 0 deletions src/systems/sparsematrixclil.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ function Base.view(a::SparseMatrixCLIL, i::Integer, ::Colon)
end
SparseArrays.nonzeroinds(a::CLILVector) = SparseArrays.nonzeroinds(a.vec)
SparseArrays.nonzeros(a::CLILVector) = SparseArrays.nonzeros(a.vec)
SparseArrays.nnz(a::CLILVector) = nnz(a.vec)

function Base.setindex!(S::SparseMatrixCLIL, v::CLILVector, i::Integer, c::Colon)
if v.vec.n != S.ncols
Expand Down

0 comments on commit 9a44828

Please sign in to comment.